51 Commits

Author SHA1 Message Date
Vincent Amato
13c3892aae Merge 8e293bbc51 into 4b2a0df237 2025-08-16 16:00:21 -04:00
Vincent Amato
8e293bbc51 Add ESM 2025-08-16 15:59:51 -04:00
Shashank
4b2a0df237 adding wwdc25 samples (#1370) 2025-06-10 10:23:25 -07:00
Denrei Keith
977cd30242 Update lora README.md (#1365)
point to the correct repository

https://github.com/ml-explore/mlx-lm
2025-05-01 06:00:14 -07:00
Param Thakkar
4c9f9f9be7 Made llama and mistral files mypy compatible (#1359)
* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-04-23 14:23:46 -07:00
Angelos Katharopoulos
c52cc748f8 Distributed FLUX (#1325) 2025-03-24 22:16:48 -07:00
Awni Hannun
c243370044 remove mlx lm (#1353) 2025-03-18 18:47:55 -07:00
Tingzhen
7ca05d2e51 LoRa/README.md should be --hf-path instead of --hf-repo (#1350)
Co-authored-by: du tingzhen <dutingzhen@macbookpro.myfiosgateway.com>
2025-03-16 20:02:52 -07:00
Awni Hannun
d9e1d9c0ef mlx-lm move notice (#1346)
* mlx-lm move notice

* remove mlx lm tests
2025-03-16 15:14:28 -07:00
Prince Canuma
2fce02acd8 Add support for Gemma3 (#1336)
* add support for gemma3

* fix model loading

* revert rmsnorm

* revert is sliding pattern

* revert

* add tests

* formatting

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update llms/mlx_lm/models/gemma3_text.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix sliding window mask

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-03-13 08:14:25 -07:00
Mirko Nasato
3e5baf583b Make sure to use UTF-8 when loading tokenizer.json (#1340) 2025-03-12 19:17:14 -07:00
Neil Mehta
4c3df00162 make_sampler creates sampler chain with all sampling parameters (#1330)
* top_p refactor

* top_k and min_p refactor

* Create sampler chain

* Remove unnecessary mx.where

* Use mx.allclose
2025-03-11 13:37:35 -07:00
Awni Hannun
d2e02b3aae fix mixed quant option (#1326) 2025-03-07 08:35:48 -08:00
Awni Hannun
595f5da146 remove lm head if unused (#1324) 2025-03-06 15:35:47 -08:00
cavit99
877d2a345b Change DEFAULT_SEED to None for stochastic generation by default (#1323)
* Change DEFAULT_SEED to None for stochastic generation by default

* Update llms/mlx_lm/chat.py

* Update llms/mlx_lm/generate.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-03-06 06:49:35 -08:00
Awni Hannun
32d10036de fix flaky test (#1322) 2025-03-05 14:00:09 -08:00
Gökdeniz Gülmez
e150621095 Adding multiple optimizers to mlx lm (#1315)
* initial commmit

* adding more customized YAML configuartion

* update YAML example file

* Changed the switch to set opt_class

* removing muon

* using default arguments

* udpate
2025-03-05 13:54:54 -08:00
Gökdeniz Gülmez
56d2db23e1 adding OLMoE architecture (#1321)
* initial commit

* udpate ACKNOWLEDGMENTS.md

* adding olmoe to training

* clean up

* faster generation

* remove sanitize method

* more clean ups

* adding SwitchGLU

* clean up

* a little faster and adding norm_topk_prob

* formated
2025-03-05 13:46:06 -08:00
Angelos Katharopoulos
e7267d30f8 Distributed support cifar (#1301) 2025-03-05 13:33:15 -08:00
Awni Hannun
f621218ff5 Tool use example (#1316)
* tool use example

* nits
2025-03-04 13:53:20 -08:00
Awni Hannun
65aa2ec849 use a bool mask for attention (#1319) 2025-03-04 12:47:32 -08:00
Pierre-Louis
1bc3476a46 chore(lora): Add real-time log buffering fix for nohup execution (#1311)
* chore(lora): Add real-time log buffering fix for nohup execution

Disable Python stdout buffering to ensure logs appear in nohup.out in real-time instead of only after script completion.

* chore(lora): remove python 3.7+ check

* chore(lora): running pre-commit hook

---------

Co-authored-by: Pierre-Louis Létoquart <randlgint@proton.me>
2025-03-03 06:12:33 -08:00
Shunta Saito
269faa5fa4 Fix plamo2 model to use rms_norm (#1308)
* Fix plamo2 model to use rms_norm and enable sliding window attention

* Fix missing variable

* Remove sliding window attention impl. cause it should be done by using RotatingKVCache

* Remove unused imports
2025-03-03 06:12:02 -08:00
Awni Hannun
845cd8c01e support kimi + more options in chat mode (#1312) 2025-02-28 11:33:18 -08:00
Awni Hannun
b2108a0de6 Allow mask prompt in config (#1314) 2025-02-28 11:33:04 -08:00
madroid
eb73549631 Generate: Support Prefill Response (#1299)
* Generate: Support Prefill Prompt

python -m mlx_lm.generate \
       --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit \
       --prompt "hello" \
       --prefill-prompt "<think>\n"

* Generate: rename prefill-prompt to prefill-response

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-02-27 07:44:00 -08:00
Awni Hannun
00a7379070 Fixes for phi4 mini (#1305) 2025-02-26 16:21:54 -08:00
Awni Hannun
0f240a4c7e Use max tokens from options in mlx_lm evaluate (#1302) 2025-02-26 15:46:16 -08:00
Awni Hannun
56e60ad5a6 fix manage for new transformers (#1304) 2025-02-26 15:44:57 -08:00
Pedro Cuenca
b7f742ef56 Mixed quant recipes (#1300)
* Mixed 3/6 and 2/6 recipes based on Alex Barron's

* format / nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-02-26 11:32:36 -08:00
Shunta Saito
c37e26a1a3 Add plamo-2-1b model (#1283)
* Add pfnet/plamo-2-1b

* Fix cache.py to support non-top level layers

* Use mlx's BaseModelArgs

* Fix model

* Use sanitize()

* Remove unnecessary changes

* Add plamo2.py

* Apply formatter

* Fix some part

* Allow a cache obj defined externally

* Fix channel first weights to channel last for right use of MLX's conv1d

* Remove unused code part

* Give all inputs when it's the first time call of model

* Fix import

* Include .jsonl files to download from Huggingface hub

* Fix reference to layers

* Remove unnecessary code and add a test for plamo2

* Do not pass mask to prepare_inputs_for_generation

* Fix to use repeat instead of tile

* Add state property to PlamoCache

* Add __iter__ and __next__ methods to PlamoCache

* cleanup

* cleanup

* fix

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-02-24 19:24:43 -08:00
Usama Ahmed
09b641aaa7 Fix FutureWarning in torch.load by setting weights_only=True (#1295) 2025-02-22 06:08:54 -08:00
Awni Hannun
3d793ecf68 Fix logits processor bugs with spec dec (#1291)
* Fix logits processor bugs with spec dec

* bump patch
2025-02-20 15:55:55 -08:00
Awni Hannun
85669451d0 Fix num layers in fine tune (#1294) 2025-02-20 13:32:01 -08:00
Awni Hannun
1cbf5cdac7 use more standard window strategy (#1287) 2025-02-19 06:22:51 -08:00
Matthias Neumayer
96bf37008e Update README.md to include how to set temperature (#1280)
* Update README.md to include how to set temperature

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-02-13 19:32:56 -08:00
Awni Hannun
7b07b14e67 add logits processor to spec gen (#1260) 2025-02-13 19:19:53 -08:00
Awni Hannun
ec30dc3538 hunyuan finetune (#1270) 2025-02-11 16:49:35 -08:00
Awni Hannun
42413c5d85 fix lora timings after validation (#1278) 2025-02-11 16:48:55 -08:00
Awni Hannun
f8cbf159e0 fix sharding for more even number of layers (#1276) 2025-02-11 16:26:59 -08:00
Awni Hannun
e879ea70e1 fix generation evaluations (#1277) 2025-02-11 16:10:30 -08:00
Matt Clayton
3d677f0870 Add "from_draft" to GenerationResponse (#1272)
* Add from_draft field in GenerationResponse

* Cleanup

* Re-work for minimal changes, add test

* Fix comment
2025-02-11 15:41:02 -08:00
Awni Hannun
bded1a8fcd fix looping in whisper (#1273) 2025-02-10 13:04:35 -08:00
Chime Ogbuji
5865899c81 Completion only fine-tuning of instruction models with collections of HF datasets (#1103)
- Optional completion only fine-tuning with `--mask-prompt`
- Collections of Hugging Face datasets

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-02-09 20:12:34 -08:00
Sri Harsha Pamu
1ced1b00ca rm temp argument (#1267) 2025-02-09 11:39:11 -08:00
Awni Hannun
f58c7de901 Some improvements to speedup alignment computation in MLX Whisper (#1259)
* some improvements to speedup alignment computation in MLX Whisper

* fix alignment
2025-02-08 15:47:00 -08:00
Awni Hannun
1503bd4f55 support hunyuan 7b (#1263) 2025-02-08 15:46:47 -08:00
Awni Hannun
31611b62d7 Add IBM granite model (#1265)
* add granite

* add thinking option
2025-02-08 15:46:15 -08:00
Awni Hannun
6120a5f376 Faster DSv2/3 expert score computation (#1257)
* fix deepseek sharding (#1242)

* compile and use put along axis in deep seek routing function
2025-02-07 10:24:57 -08:00
Awni Hannun
52c41b5b5a Fix prompt cache for models without chat template (#1250)
* fix deepseek sharding (#1242)

* fix prompt cache with no chat template
2025-02-06 11:10:58 -08:00
Nripesh Niketan
747c08e202 Chore: pre-commit bump (#1253) 2025-02-06 09:06:31 -08:00
152 changed files with 9817 additions and 19015 deletions

View File

@@ -17,30 +17,6 @@ jobs:
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
mlx_lm_build_and_test:
macos:
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install unittest-xml-reporting
cd llms/
pip install -e ".[test]"
- run:
name: Run Python tests
command: |
source env/bin/activate
python -m xmlrunner discover -v llms/tests -o test-results/
- store_test_results:
path: test-results
workflows:
build_and_test:
when:
@@ -48,7 +24,6 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- mlx_lm_build_and_test
- linux_build_and_test
prb:
@@ -61,7 +36,5 @@ workflows:
type: approval
- apple/authenticate:
context: pr-approval
- mlx_lm_build_and_test:
requires: [ hold ]
- linux_build_and_test:
requires: [ hold ]

View File

@@ -1,10 +1,10 @@
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 6.0.0
hooks:
- id: isort
args:

View File

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`.

View File

@@ -4,12 +4,12 @@ This repo contains a variety of standalone examples using the [MLX
framework](https://github.com/ml-explore/mlx).
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples are listed below.
Some more useful examples are listed below. Check-out [MLX
LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python
package for LLMs with MLX.
### Text Models
- [MLX LM](llms/README.md) a package for LLM text generation, fine-tuning, and more.
- [Transformer language model](transformer_lm) training.
- Minimal examples of large scale text generation with [LLaMA](llms/llama),
[Mistral](llms/mistral), and more in the [LLMs](llms) directory.
@@ -30,6 +30,7 @@ Some more useful examples are listed below.
- Speech recognition with [OpenAI's Whisper](whisper).
- Audio compression and generation with [Meta's EnCodec](encodec).
- Music generation with [Meta's MusicGen](musicgen).
### Multimodal models

View File

@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.
## Distributed training
The example also supports distributed data parallel training. You can launch a
distributed training as follows:
```shell
$ cat >hostfile.json
[
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
```

View File

@@ -1,3 +1,4 @@
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
x = x.astype("float32") / 255.0
return (x - mean) / std
group = mx.distributed.init()
tr_iter = (
tr.shuffle()
.partition_if(group.size() > 1, group.size(), group.rank())
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
)
test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
test_iter = (
test.to_stream()
.partition_if(group.size() > 1, group.size(), group.rank())
.key_transform("image", normalize)
.batch(batch_size)
)
return tr_iter, test_iter

View File

@@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
def print_zero(group, *args, **kwargs):
if group.rank() != 0:
return
flush = kwargs.pop("flush", True)
print(*args, **kwargs, flush=flush)
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
@@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
losses = []
accs = []
samples_per_sec = []
world = mx.distributed.init()
losses = 0
accuracies = 0
samples_per_sec = 0
count = 0
def average_stats(stats, count):
if world.size() == 1:
return [s / count for s in stats]
with mx.stream(mx.cpu):
stats = mx.distributed.all_sum(mx.array(stats))
count = mx.distributed.all_sum(count)
return (stats / count).tolist()
state = [model.state, optimizer.state]
@@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
def step(inp, tgt):
train_step_fn = nn.value_and_grad(model, train_step)
(loss, acc), grads = train_step_fn(model, inp, tgt)
grads = nn.utils.average_gradients(grads)
optimizer.update(model, grads)
return loss, acc
@@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
y = mx.array(batch["label"])
tic = time.perf_counter()
loss, acc = step(x, y)
mx.eval(state)
mx.eval(loss, acc, state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
losses += loss.item()
accuracies += acc.item()
samples_per_sec += x.shape[0] / (toc - tic)
count += 1
if batch_counter % 10 == 0:
print(
l, a, s = average_stats(
[losses, accuracies, world.size() * samples_per_sec],
count,
)
print_zero(
world,
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} images/second",
f"Train loss {l:.3f}",
f"Train acc {a:.3f}",
f"Throughput: {s:.2f} images/second",
)
)
),
)
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
def test_epoch(model, test_iter, epoch):
accs = []
accuracies = 0
count = 0
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
accuracies += acc.item()
count += 1
with mx.stream(mx.cpu):
accuracies = mx.distributed.all_sum(accuracies)
count = mx.distributed.all_sum(count)
return (accuracies / count).item()
def main(args):
mx.random.seed(args.seed)
# Initialize the distributed group and report the nodes that showed up
world = mx.distributed.init()
if world.size() > 1:
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
model = getattr(resnet, args.arch)()
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
optimizer = optim.Adam(learning_rate=args.lr)
train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print(
print_zero(
world,
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} images/sec",
f"avg. Train loss {tr_loss:.3f}",
f"avg. Train acc {tr_acc:.3f}",
f"Throughput: {throughput:.2f} images/sec",
)
)
),
)
test_acc = test_epoch(model, test_data, epoch)
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
train_data.reset()
test_data.reset()

View File

@@ -121,7 +121,7 @@ if __name__ == "__main__":
mlx_path.mkdir(parents=True, exist_ok=True)
print("[INFO] Loading")
torch_weights = torch.load(torch_path / "pytorch_model.bin")
torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
print("[INFO] Converting")
mlx_weights = {
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()

156
esm/README.md Normal file
View File

@@ -0,0 +1,156 @@
# ESM-2
This repository provides an implementation of Meta's ESM-2 protein language model
in MLX.[^1] ESM-2 is Metas second-generation Evolutionary Scale Model, a
transformer-based protein language model trained on millions of diverse protein
sequences with a masked language modeling objective.
![Example contact prediction map](assets/contact_prediction.png)
_Example contact prediction map for a universal stress protein. In this case, ESM-2 650M achieves 86.4% precision at long-range contacts._
## Setup
Install the requirements:
```bash
pip install -r requirements.txt
```
## Usage
Below are the available ESM-2 models:
| Model | Parameters | Layers |
|-------|------------|--------|
| [`esm2_t6_8M_UR50D`](https://huggingface.co/facebook/esm2_t6_8M_UR50D) | 8M | 6 |
| [`esm2_t12_35M_UR50D`](https://huggingface.co/facebook/esm2_t12_35M_UR50D) | 35M | 12 |
| [`esm2_t30_150M_UR50D`](https://huggingface.co/facebook/esm2_t30_150M_UR50D) | 150M | 30 |
| [`esm2_t33_650M_UR50D`](https://huggingface.co/facebook/esm2_t33_650M_UR50D) | 650M | 33 |
| [`esm2_t36_3B_UR50D`](https://huggingface.co/facebook/esm2_t36_3B_UR50D) | 3B | 36 |
| [`esm2_t48_15B_UR50D`](https://huggingface.co/facebook/esm2_t48_15B_UR50D) | 15B | 48 |
Convert a model to MLX format:
```bash
python convert.py --hf-path facebook/esm2_t33_650M_UR50D
```
This will save the converted model in a checkpoints directory.
### Basic Inference
```python
from esm import ESM2
# Load model and tokenizer
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
# Example protein sequence (human insulin)
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
# Tokenize and run inference
tokens = tokenizer.encode(sequence)
result = model(tokens)
logits = result["logits"] # Shape: (batch, length, vocab_size)
```
### Masked Language Modeling
```bash
# For a complete example, see main.py
python main.py --sequence "YOUR_SEQUENCE" --mask-position 50
```
### Embeddings
```python
# Get sequence-level representations
seq_repr = model.get_sequence_representations(tokens, layer=-1) # Shape: (batch, embed_dim)
# Extract per-residue representations from specific layers
representations = model.extract_features(tokens, repr_layers=[20, 30, 33])
final_layer = representations[33] # Shape: (batch, length, embed_dim)
```
### Contact Prediction
```python
# Predict residue-residue contacts
contacts = model.predict_contacts(tokens) # Shape: (batch, length, length)
# Or compute contacts together with logits, representations, etc.
outputs = model(tokens, return_contacts=True)
contacts = outputs["contacts"]
```
### Examples
**Mutation Effect Prediction**: [notebooks/mutation_effect_prediction.ipynb](notebooks/mutation_effect_prediction.ipynb)
This notebook demonstrates how to use ESM-2 for zero-shot mutation effect prediction by scoring amino acid substitutions based on their likelihood under the model. We validate the approach using experimental fitness data from β-lactamase TEM, showing how ESM-2 captures functional constraints without requiring structural information.
**Embeddings**: [notebooks/embeddings.ipynb](notebooks/embeddings.ipynb)
This notebook explores how ESM-2 generates meaningful protein embeddings that capture evolutionary and functional relationships between proteins. We analyze six diverse human proteins to demonstrate how the learned representations cluster proteins by function and reveal biological similarities.
**Contact Prediction**: [notebooks/contact_prediction.ipynb](notebooks/contact_prediction.ipynb)
This notebook shows how to predict residue-residue contacts in protein structures using ESM-2's attention patterns. We evaluate contact prediction performance on three diverse proteins, demonstrating how the model captures both local and long-range structural relationships directly from sequence data.
### Benchmarking
Benchmark MLX performance:
```bash
python benchmarks/benchmark_mx.py
```
Benchmark PyTorch MPS performance:
```bash
python benchmarks/benchmark_pt.py
```
Expected performance on M4 MacBook Pro (ESM-2 650M, batch_size = 5):
- MLX: 299 ms per step, 16.71 sequences/sec
- PyTorch MPS: 402 ms per step, 12.43 sequences/sec
### Testing
Verify correctness against original implementation:
```bash
python test.py
```
This tests tokenizer and model outputs (logits, hidden states, and attentions) for equivalence with the original implementation.
### Citations:
```bibtex
@article{rives2019biological,
author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Liu, Jason and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
year={2019},
doi={10.1101/622803},
url={https://www.biorxiv.org/content/10.1101/622803v4},
journal={PNAS}
}
```
```bibtex
@article{Lin2023,
author={Zeming Lin et al.},
title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
journal={Science},
volume={379},
pages={1123--1130},
year={2023},
doi={10.1126/science.ade2574},
url={https://doi.org/10.1126/science.ade2574}
}
```
[^1]: Refer to the [paper](https://www.science.org/doi/10.1126/science.ade2574) and [code](https://github.com/facebookresearch/esm) for more details.

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

View File

@@ -0,0 +1,47 @@
import sys
import time
from pathlib import Path
import mlx.core as mx
# Add parent directory to Python path
cur_path = Path(__file__).parents[1].resolve()
sys.path.append(str(cur_path))
from esm import ESM2
# Example protein sequence (Green Fluorescent Protein)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# Load pretrained ESM-2 model and its tokenizer from local checkpoint
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
# Number of sequences to process in each forward pass
batch_size = 5
# Number of timing iterations for performance measurement
steps = 50
# Tokenize the protein sequence into integer IDs for the model
# Replicate the same sequence 'batch_size' times to create a batch
tokens = tokenizer.batch_encode([protein_sequence] * batch_size)
# Warm-up phase
for _ in range(10):
result = model(tokens)
mx.eval(result["logits"]) # Force computation to complete
# Measure average inference time over 'steps' iterations
tic = time.time()
for _ in range(steps):
result = model(tokens)
mx.eval(result["logits"]) # Synchronize and ensure computation finishes
toc = time.time()
# Compute metrics: average time per step (ms) and throughput (sequences/sec)
ms_per_step = 1000 * (toc - tic) / steps
throughput = batch_size * 1000 / ms_per_step
# Display results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")

View File

@@ -0,0 +1,52 @@
import time
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
# Example protein sequence (Green Fluorescent Protein)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# Hugging Face model identifier for ESM-2 (33 layers, 650M params, UR50D training set)
model_name = "facebook/esm2_t33_650M_UR50D"
# Load tokenizer and model; move model to Apple Metal Performance Shaders (MPS) device
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name).to("mps")
# Number of sequences per forward pass
batch_size = 5
# Number of timing iterations
steps = 50
# Tokenize input sequence and replicate for the batch
# Replicate the same sequence 'batch_size' times to create a batch
inputs = tokenizer(
[protein_sequence] * batch_size,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
)
input_ids = inputs["input_ids"].to("mps")
attention_mask = inputs["attention_mask"].to("mps")
# Warm-up phase
for _ in range(10):
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
torch.mps.synchronize() # Ensure all queued ops on MPS are complete before next step
# Timed inference loop
tic = time.time()
for _ in range(steps):
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
torch.mps.synchronize() # Wait for computation to finish before timing next iteration
toc = time.time()
# Compute performance metrics
ms_per_step = 1000 * (toc - tic) / steps
throughput = batch_size * 1000 / ms_per_step
# Report results
print(f"Time (ms) per step: {ms_per_step:.3f}")
print(f"Throughput: {throughput:.2f} sequences/sec")

177
esm/convert.py Normal file
View File

@@ -0,0 +1,177 @@
import argparse
import json
import shutil
from pathlib import Path
from typing import Dict
import mlx.core as mx
import torch
from huggingface_hub import snapshot_download
def download(hf_repo: str) -> Path:
"""Download model from Hugging Face."""
return Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.safetensors", "*.json", "*.bin", "*.txt"],
)
)
def remap_key(key: str) -> str:
"""Remap HuggingFace ESM key names to MLX format."""
# Skip position embeddings and position_ids
if "position_embeddings" in key or "position_ids" in key:
return None
# Map lm_head components properly
if key == "lm_head.decoder.weight":
return "lm_head.weight"
if key == "lm_head.decoder.bias":
return "lm_head.bias"
if key == "lm_head.dense.weight":
return "lm_head.dense.weight"
if key == "lm_head.dense.bias":
return "lm_head.dense.bias"
if key == "lm_head.layer_norm.weight":
return "lm_head.layer_norm.weight"
if key == "lm_head.layer_norm.bias":
return "lm_head.layer_norm.bias"
# Core remapping patterns
key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
key = key.replace("esm.encoder.layer.", "layer_")
key = key.replace("esm.contact_head", "contact_head")
key = key.replace("lm_head", "lm_head")
# Attention patterns
key = key.replace(".attention.self.", ".self_attn.")
key = key.replace(".attention.output.dense", ".self_attn.out_proj")
key = key.replace(".attention.LayerNorm", ".self_attn_layer_norm")
key = key.replace(".query", ".q_proj")
key = key.replace(".key", ".k_proj")
key = key.replace(".value", ".v_proj")
key = key.replace(".rotary_embeddings", ".rot_emb")
# FFN patterns
key = key.replace(".intermediate.dense", ".fc1")
key = key.replace(".output.dense", ".fc2")
key = key.replace(".LayerNorm", ".final_layer_norm")
return key
def load_weights(model_path: Path) -> Dict:
"""Load weights from safetensors or PyTorch bin files."""
# Check for safetensors file
safetensors_path = model_path / "model.safetensors"
if safetensors_path.exists():
print("Loading from safetensors...")
return mx.load(str(safetensors_path))
# Check for single bin file
single_bin_path = model_path / "pytorch_model.bin"
if single_bin_path.exists():
print("Loading from pytorch_model.bin...")
state_dict = torch.load(str(single_bin_path), map_location="cpu")
return {k: v.numpy() for k, v in state_dict.items()}
# Check for sharded bin files
index_file = model_path / "pytorch_model.bin.index.json"
if index_file.exists():
print("Loading from sharded bin files...")
with open(index_file) as f:
index = json.load(f)
# Get unique shard files
shard_files = set(index["weight_map"].values())
# Load all shards
state_dict = {}
for shard_file in sorted(shard_files):
print(f" Loading shard: {shard_file}")
shard_path = model_path / shard_file
shard_dict = torch.load(str(shard_path), map_location="cpu")
state_dict.update(shard_dict)
return {k: v.numpy() for k, v in state_dict.items()}
raise ValueError(f"No model weights found in {model_path}")
def convert(model_path: Path) -> Dict[str, mx.array]:
"""Convert ESM weights to MLX format."""
# Load weights from any format
weights = load_weights(model_path)
# Convert keys and create MLX arrays
mlx_weights = {}
for key, value in weights.items():
mlx_key = remap_key(key)
if mlx_key is not None:
mlx_weights[mlx_key] = (
mx.array(value) if not isinstance(value, mx.array) else value
)
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
return mlx_weights
def main():
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
parser.add_argument(
"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
)
parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
parser.add_argument(
"--checkpoints-dir",
default="checkpoints",
help="Directory to save checkpoints (default: checkpoints)",
)
args = parser.parse_args()
# Download model
print(f"Downloading {args.hf_path}...")
model_path = download(args.hf_path)
# Set output path
if args.mlx_path is None:
model_name = args.hf_path.split("/")[-1]
checkpoints_dir = Path(args.checkpoints_dir)
checkpoints_dir.mkdir(parents=True, exist_ok=True)
args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
# Convert weights
print("Converting weights...")
mlx_weights = convert(model_path)
# Save weights
print(f"Saving MLX weights to {mlx_path}...")
mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)
# Copy config and other files
print("Copying config...")
shutil.copy(model_path / "config.json", mlx_path / "config.json")
for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
src_file = model_path / file_name
if src_file.exists():
shutil.copy(src_file, mlx_path / file_name)
print(f"Conversion complete! MLX model saved to {mlx_path}")
if __name__ == "__main__":
main()

19
esm/esm/__init__.py Normal file
View File

@@ -0,0 +1,19 @@
"""
ESM-2 protein language model implementation in MLX
"""
from .attention import MultiheadAttention
from .model import ESM2
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .rotary_embedding import RotaryEmbedding
from .tokenizer import ProteinTokenizer
__all__ = [
"ESM2",
"ProteinTokenizer",
"ContactPredictionHead",
"RobertaLMHead",
"TransformerLayer",
"MultiheadAttention",
"RotaryEmbedding",
]

153
esm/esm/attention.py Normal file
View File

@@ -0,0 +1,153 @@
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .rotary_embedding import RotaryEmbedding
class MultiheadAttention(nn.Module):
"""
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
This module implements both self-attention (when `key` and `value` are not
provided) and cross-attention. It projects input sequences into queries,
keys, and values, applies rotary position embeddings to encode relative
position information, computes scaled dot-product attention over multiple
heads in parallel, and returns a combined output projection.
Args:
embed_dim (int): Total embedding dimension of the model input and output.
num_heads (int): Number of parallel attention heads. Must divide `embed_dim`.
"""
def __init__(
self,
embed_dim,
num_heads,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim**-0.5
# Linear projections for queries, keys, and values (with bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# Linear projection for output (with bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# ESM-2 uses rotary embeddings
self.rot_emb = RotaryEmbedding(dim=self.head_dim)
def __call__(
self,
query,
key: Optional[mx.array] = None,
value: Optional[mx.array] = None,
key_padding_mask: Optional[mx.array] = None,
attn_mask: Optional[mx.array] = None,
need_head_weights: bool = False,
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Multi-head attention forward pass.
Args:
query: Tensor of shape (tgt_len, batch, embed_dim).
key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
value: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
key_padding_mask: Optional mask of shape (batch, src_len) to ignore padded positions.
attn_mask: Optional mask for attention (e.g., causal mask).
need_head_weights: If True, return attention weights for each head separately.
Returns:
attn_output: Tensor of shape (tgt_len, batch, embed_dim).
attn_weights_out: Attention weights of shape
(num_heads, batch, tgt_len, src_len) if per-head,
or (batch, tgt_len, src_len) if averaged.
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
# For self-attention, use query as key and value if not provided
if key is None:
key = query
if value is None:
value = query
# Project queries, keys, values
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q = q * self.scaling
# Reshape for multi-head attention
q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
src_len = k.shape[1]
# Apply rotary embeddings if present
if self.rot_emb:
q, k = self.rot_emb(q, k)
# Compute attention weights
attn_weights = q @ k.swapaxes(-2, -1)
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
# Apply attention mask
if attn_mask is not None:
attn_mask = mx.expand_dims(attn_mask, 0)
attn_weights = attn_weights + attn_mask
# Apply key padding mask
if key_padding_mask is not None:
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
# Convert key_padding_mask to boolean and expand dimensions
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
mask = mx.expand_dims(
mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
)
# Apply mask: set attention to -inf where mask is True (padded positions)
attn_weights = mx.where(mask, -mx.inf, attn_weights)
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
# Apply softmax
attn_weights_float = mx.softmax(attn_weights.astype(mx.float32), axis=-1)
attn_probs = attn_weights_float
# Compute attention output
attn = attn_probs @ v
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
# Reshape output
attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
# Return attention weights if requested
attn_weights_out: Optional[mx.array] = None
if need_head_weights:
# Return attention weights for each head separately
attn_weights_out = (
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
.astype(attn.dtype)
.swapaxes(0, 1)
)
else:
# Return averaged attention weights
attn_weights_out = mx.mean(
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
axis=1,
).astype(attn.dtype)
return attn, attn_weights_out

340
esm/esm/model.py Normal file
View File

@@ -0,0 +1,340 @@
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
from .tokenizer import ProteinTokenizer
class ESM2(nn.Module):
"""
ESM-2 protein language model in MLX.
Args:
num_layers (int): Number of transformer layers.
embed_dim (int): Embedding dimension.
attention_heads (int): Number of attention heads.
tokenizer (Optional[ProteinTokenizer]): Tokenizer to use (created if None).
token_dropout (bool): Apply token-dropout masking behavior.
"""
def __init__(
self,
num_layers: int = 33,
embed_dim: int = 1280,
attention_heads: int = 20,
tokenizer: Optional[ProteinTokenizer] = None,
token_dropout: bool = True,
):
super().__init__()
self.num_layers = num_layers
self.embed_dim = embed_dim
self.attention_heads = attention_heads
# Initialize tokenizer
if tokenizer is None:
tokenizer = ProteinTokenizer()
self.tokenizer = tokenizer
self.vocab_size = len(tokenizer)
# Special token IDs / config
self.padding_idx = tokenizer.pad_id
self.mask_idx = tokenizer.mask_id
self.cls_idx = tokenizer.cls_id
self.eos_idx = tokenizer.eos_id
self.prepend_bos = tokenizer.prepend_bos
self.append_eos = tokenizer.append_eos
self.token_dropout = token_dropout
self._init_submodules()
def _init_submodules(self) -> None:
"""Initialize embeddings, transformer stack, and output heads."""
self.embed_scale = 1
# Token embeddings
self.embed_tokens = nn.Embedding(self.vocab_size, self.embed_dim)
# Transformer layers (register each layer so MLX tracks parameters)
self._layer_indices = list(range(self.num_layers))
for i in self._layer_indices:
layer = TransformerLayer(
self.embed_dim,
4 * self.embed_dim, # FFN dimension = 4×embed_dim
self.attention_heads,
)
setattr(self, f"layer_{i}", layer)
# Contact prediction head (uses all layers × heads attentions)
self.contact_head = ContactPredictionHead(
self.num_layers * self.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
# Final norm + LM head (tied weights)
self.emb_layer_norm_after = nn.LayerNorm(self.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.embed_dim,
output_dim=self.vocab_size,
weight=self.embed_tokens.weight,
)
def __call__(
self,
tokens: mx.array,
repr_layers: List[int] = [],
need_head_weights: bool = False,
return_contacts: bool = False,
) -> Dict[str, mx.array]:
"""
Forward pass through ESM-2.
Args:
tokens: Tensor of token IDs with shape (B, T).
repr_layers: Layers to return hidden states from (0..num_layers).
need_head_weights: If True, return attention weights.
return_contacts: If True, compute residue-residue contact probabilities.
Returns:
dict with:
logits: (B, T, V)
representations: {layer_idx: (B, T, E)}
attentions: (B, L, H, T, T) if requested
contacts: (B, T', T') if requested
"""
if return_contacts:
need_head_weights = True
# Ensure tokens is 2D (B, T)
if tokens.ndim == 1:
tokens = mx.expand_dims(tokens, axis=0)
assert tokens.ndim == 2
# Padding mask (B, T)
padding_mask = mx.equal(tokens, self.padding_idx)
# Embed tokens (B, T, E)
x = self.embed_scale * self.embed_tokens(tokens)
# Token dropout: zero masked tokens + rescale based on observed mask ratio
if self.token_dropout:
# x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
mask_positions = mx.equal(tokens, self.mask_idx)
x = mx.where(mask_positions[:, :, None], 0.0, x)
# x: B x T x C
mask_ratio_train = 0.15 * 0.8
src_lengths = mx.sum(~padding_mask, axis=-1) # Shape: (B,)
mask_ratio_observed = mx.sum(mask_positions, axis=-1).astype(x.dtype) / src_lengths # Shape: (B,)
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
# Zero out padding positions
if padding_mask.any():
x = x * (1 - padding_mask[:, :, None].astype(x.dtype))
# Track requested representations
repr_layers = set(repr_layers)
hidden_representations: Dict[int, mx.array] = {}
if 0 in repr_layers:
hidden_representations[0] = x
if need_head_weights:
attn_weights: List[mx.array] = []
# (B, T, E) -> (T, B, E) for transformer layers
x = mx.swapaxes(x, 0, 1)
# If no padding anywhere, skip the mask
if not padding_mask.any():
padding_mask = None
# Transformer stack
for layer_idx in self._layer_indices:
layer = getattr(self, f"layer_{layer_idx}")
x, attn = layer(
x,
self_attn_padding_mask=padding_mask,
need_head_weights=need_head_weights,
)
# Save hidden representation if requested (store back as (B, T, E))
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = mx.swapaxes(x, 0, 1)
# Save per-layer attentions if requested (H, B, T, T) -> (B, H, T, T)
if need_head_weights:
attn_weights.append(mx.swapaxes(attn, 0, 1))
# Final layer norm, back to (B, T, E)
x = self.emb_layer_norm_after(x)
x = mx.swapaxes(x, 0, 1)
# Save final hidden if requested
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
# Language modeling logits
x = self.lm_head(x)
# Build result dict
result: Dict[str, mx.array] = {
"logits": x,
"representations": hidden_representations,
}
# Collect attentions and optional contacts
if need_head_weights:
# Stack layers -> (B, L, H, T, T)
attentions = mx.stack(attn_weights, axis=1)
# Mask out padded positions if present
if padding_mask is not None:
attention_mask = 1 - padding_mask.astype(attentions.dtype)
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
attention_mask, 2
)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
# Compute contacts if requested
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
return result
def predict_contacts(self, tokens: mx.array) -> mx.array:
"""
Predict residue-residue contacts.
Args:
tokens: Tensor of shape (B, T).
Returns:
mx.array: Contact probabilities of shape (B, T', T').
"""
return self(tokens, return_contacts=True)["contacts"]
def extract_features(
self,
tokens: mx.array,
repr_layers: Optional[List[int]] = None,
return_all_hiddens: bool = False,
) -> Dict[int, mx.array]:
"""
Extract hidden representations from selected layers.
Args:
tokens: Tensor of shape (B, T).
repr_layers: Layer indices to return (default: last layer).
return_all_hiddens: If True, return all layers (0..num_layers).
Returns:
dict: {layer_idx: (B, T, E)} for requested layers.
"""
if return_all_hiddens:
repr_layers = list(range(self.num_layers + 1))
elif repr_layers is None:
repr_layers = [self.num_layers]
result = self(tokens, repr_layers=repr_layers)
return result["representations"]
def get_sequence_representations(
self,
tokens: mx.array,
layer: int = -1,
) -> mx.array:
"""
Average token representations into a per-sequence embedding.
Args:
tokens: Tensor of shape (B, T).
layer: Layer index to use (-1 or num_layers for last).
Returns:
mx.array: Sequence embeddings of shape (B, E).
"""
if layer == -1:
layer = self.num_layers
representations = self.extract_features(tokens, repr_layers=[layer])
repr = representations[layer]
# Mask: non-padding and not CLS; optionally not EOS
mask = mx.logical_and(
mx.not_equal(tokens, self.padding_idx),
mx.not_equal(tokens, self.cls_idx),
)
if self.append_eos:
mask = mx.logical_and(mask, mx.not_equal(tokens, self.eos_idx))
# Mean over valid positions
mask = mask[:, :, None].astype(repr.dtype)
masked_repr = repr * mask
seq_lens = mx.sum(mask, axis=1, keepdims=True)
seq_repr = mx.sum(masked_repr, axis=1) / mx.maximum(seq_lens[:, :, 0], 1.0)
return seq_repr
@classmethod
def from_pretrained(cls, model_path: str) -> Tuple[ProteinTokenizer, "ESM2"]:
"""
Load model weights and config from a directory.
Expects:
- config.json
- model.safetensors
- vocab.txt (optional, will use default if not found)
- special_tokens_map.json (optional, will use default if not found)
Args:
model_path: Path to directory with weights and config.
Returns:
(tokenizer, model): Initialized tokenizer and ESM2 model.
"""
model_dir = Path(model_path)
config_path = model_dir / "config.json"
with open(config_path, "r") as f:
config = json.load(f)
# Check for vocab and special tokens files
vocab_path = model_dir / "vocab.txt"
special_tokens_path = model_dir / "special_tokens_map.json"
if vocab_path.exists() and special_tokens_path.exists():
tokenizer = ProteinTokenizer(
vocab_file=str(vocab_path),
special_tokens_map_file=str(special_tokens_path),
)
else:
tokenizer = ProteinTokenizer()
model = cls(
num_layers=config["num_hidden_layers"],
embed_dim=config["hidden_size"],
attention_heads=config["num_attention_heads"],
tokenizer=tokenizer,
token_dropout=config["token_dropout"],
)
# Load safetensors as nested dict and update model params
weights_path = model_dir / "model.safetensors"
flat_weights = mx.load(str(weights_path))
nested_weights: Dict[str, dict] = {}
for key, value in flat_weights.items():
parts = key.split(".")
cur = nested_weights
for p in parts[:-1]:
cur = cur.setdefault(p, {})
cur[parts[-1]] = value
model.update(nested_weights)
return tokenizer, model

212
esm/esm/modules.py Normal file
View File

@@ -0,0 +1,212 @@
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .attention import MultiheadAttention
def symmetrize(x: mx.array) -> mx.array:
"""
Make a tensor symmetric over its last two dimensions.
Args:
x: Tensor with shape (..., L, L).
Returns:
mx.array: Symmetrized tensor of shape (..., L, L).
"""
# Add tensor to its transpose over the last two dims
return x + mx.swapaxes(x, -1, -2)
def apc(x: mx.array) -> mx.array:
"""
Apply Average Product Correction (APC) to remove background co-variation.
Args:
x: Tensor with shape (..., L, L).
Returns:
mx.array: APC-corrected tensor of shape (..., L, L).
"""
# Compute row, column, and total sums
a1 = mx.sum(x, axis=-1, keepdims=True)
a2 = mx.sum(x, axis=-2, keepdims=True)
a12 = mx.sum(x, axis=(-1, -2), keepdims=True)
# Expected co-variation under independence
expected = (a1 * a2) / a12
return x - expected
class TransformerLayer(nn.Module):
"""
Transformer layer used in ESM-2.
Args:
embed_dim (int): Model embedding dimension.
ffn_embed_dim (int): Hidden dimension of the feed-forward network.
attention_heads (int): Number of attention heads.
"""
def __init__(
self,
embed_dim: int,
ffn_embed_dim: int,
attention_heads: int,
):
super().__init__()
self.embed_dim = embed_dim
self.ffn_embed_dim = ffn_embed_dim
self.attention_heads = attention_heads
self._init_submodules()
def _init_submodules(self) -> None:
"""Initialize attention, norms, and feed-forward submodules."""
self.self_attn = MultiheadAttention(self.embed_dim, self.attention_heads)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def __call__(
self,
x: mx.array,
self_attn_mask: Optional[mx.array] = None,
self_attn_padding_mask: Optional[mx.array] = None,
need_head_weights: bool = False,
):
"""
Forward pass for the Transformer layer.
Args:
x: Tensor of shape (seq_len, batch, embed_dim).
self_attn_mask: Optional attention mask.
self_attn_padding_mask: Optional padding mask of shape (batch, seq_len).
need_head_weights: If True, return per-head attention weights.
Returns:
x: Tensor of shape (seq_len, batch, embed_dim).
attn: Attention weights of shape
(num_heads, batch, tgt_len, src_len) if per-head,
or (batch, tgt_len, src_len) if averaged.
"""
# Self-attention block
residual = x
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
need_head_weights=need_head_weights,
)
x = residual + x
# Feed-forward block
residual = x
x = self.final_layer_norm(x)
x = nn.gelu(self.fc1(x))
x = self.fc2(x)
x = residual + x
return x, attn
class RobertaLMHead(nn.Module):
"""
Masked Language Modeling (MLM) head with tied weights.
Args:
embed_dim (int): Embedding dimension of the backbone.
output_dim (int): Vocabulary size.
weight (mx.array): Embedding weight matrix for tied projection.
"""
def __init__(self, embed_dim: int, output_dim: int, weight: mx.array):
super().__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.layer_norm = nn.LayerNorm(embed_dim)
self.weight = weight
self.bias = mx.zeros(output_dim)
def __call__(self, features: mx.array) -> mx.array:
"""
Forward pass for the MLM head.
Args:
features: Tensor of shape (seq_len, batch, embed_dim).
Returns:
mx.array: Logits of shape (seq_len, batch, output_dim).
"""
# Transform features before projection to vocab
x = self.dense(features)
x = nn.gelu(x)
x = self.layer_norm(x)
return mx.matmul(x, self.weight.T) + self.bias
class ContactPredictionHead(nn.Module):
"""
Predict residue-residue contact probabilities from attention maps.
Args:
in_features (int): Number of attention channels (layers × heads).
prepend_bos (bool): If True, drop BOS/CLS token attentions.
append_eos (bool): If True, drop EOS token attentions.
bias (bool): Whether the regression layer uses a bias term.
eos_idx (Optional[int]): Token ID for EOS; required if append_eos=True.
"""
def __init__(
self,
in_features: int,
prepend_bos: bool,
append_eos: bool,
bias: bool = True,
eos_idx: Optional[int] = None,
):
super().__init__()
self.in_features = in_features
self.prepend_bos = prepend_bos
self.append_eos = append_eos
if append_eos and eos_idx is None:
raise ValueError("append_eos=True but eos_idx was not provided.")
self.eos_idx = eos_idx
self.regression = nn.Linear(in_features, 1, bias=bias)
def __call__(self, tokens: mx.array, attentions: mx.array) -> mx.array:
"""
Forward pass for contact prediction.
Args:
tokens: Tensor of shape (B, T).
attentions: Tensor of shape (B, L, H, T, T).
Returns:
mx.array: Contact probabilities of shape (B, T', T'),
where T' = T - [prepend_bos] - [append_eos].
"""
# Remove EOS attentions if requested
if self.append_eos:
eos_mask = mx.not_equal(tokens, self.eos_idx).astype(attentions.dtype)
eos_mask = eos_mask[:, None, :] * eos_mask[:, :, None]
attentions = attentions * eos_mask[:, None, None, :, :]
attentions = attentions[..., :-1, :-1]
# Remove BOS attentions if requested
if self.prepend_bos:
attentions = attentions[..., 1:, 1:]
# Merge (layers × heads) into channel dimension
batch_size, layers, heads, seqlen, _ = attentions.shape
attentions = attentions.reshape(batch_size, layers * heads, seqlen, seqlen)
# Symmetrize and apply APC to enhance contact signal
attentions = apc(symmetrize(attentions))
# Apply logistic regression over channels
attentions = mx.transpose(attentions, axes=[0, 2, 3, 1])
logits = self.regression(attentions)
return nn.sigmoid(mx.squeeze(logits, axis=3))

114
esm/esm/rotary_embedding.py Normal file
View File

@@ -0,0 +1,114 @@
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
def rotate_half(x: mx.array) -> mx.array:
"""
Rotate last dimension by splitting into two halves and swapping.
Args:
x: Tensor with even-sized last dimension.
Returns:
mx.array: Tensor of same shape as `x` with halves rotated.
"""
# Split into two equal halves along the last dimension
x1, x2 = mx.split(x, 2, axis=-1)
# Swap halves and negate the second half
return mx.concatenate((-x2, x1), axis=-1)
def apply_rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
"""
Apply rotary position embeddings to a tensor.
Args:
x: Input tensor of shape (..., seq_len, dim).
cos: Cosine embedding table of shape (1, seq_len, dim).
sin: Sine embedding table of shape (1, seq_len, dim).
Returns:
mx.array: Tensor with rotary position embeddings applied.
"""
# Trim cos/sin to match the sequence length of x
cos = cos[:, : x.shape[-2], :]
sin = sin[:, : x.shape[-2], :]
# Elementwise rotation: (x * cos) + (rotate_half(x) * sin)
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbedding(nn.Module):
"""
Rotary position embedding (RoPE) module.
Args:
dim (int): Head dimension size (must be even).
"""
def __init__(self, dim: int, *_, **__):
super().__init__()
# Precompute inverse frequency for each pair of dimensions
self.inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
# Cache for cosine/sine tables to avoid recomputation
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(
self, x: mx.array, seq_dimension: int = 1
) -> Tuple[mx.array, mx.array]:
"""
Compute and cache cos/sin tables for the given sequence length.
Args:
x: Reference tensor for sequence length.
seq_dimension: Axis containing the sequence length.
Returns:
Tuple of:
cos: Cosine table of shape (1, seq_len, dim).
sin: Sine table of shape (1, seq_len, dim).
"""
seq_len = x.shape[seq_dimension]
# Only update cache if sequence length has changed
if seq_len != self._seq_len_cached:
self._seq_len_cached = seq_len
# Time steps: shape (seq_len,)
t = mx.arange(seq_len).astype(self.inv_freq.dtype)
# Outer product between time and inverse frequency
freqs = mx.einsum("i,j->ij", t, self.inv_freq)
# Duplicate frequencies for cos/sin dimensions
emb = mx.concatenate((freqs, freqs), axis=-1)
self._cos_cached = mx.cos(emb)[None, :, :]
self._sin_cached = mx.sin(emb)[None, :, :]
return self._cos_cached, self._sin_cached
def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
"""
Apply rotary position embeddings to queries and keys.
Args:
q: Query tensor of shape (..., seq_len, dim).
k: Key tensor of shape (..., seq_len, dim).
Returns:
Tuple of:
q_rot: Query tensor with RoPE applied.
k_rot: Key tensor with RoPE applied.
"""
# Get (and cache) cos/sin tables based on key sequence length
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
k, seq_dimension=-2
)
# Apply rotary embeddings to both q and k
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)

241
esm/esm/tokenizer.py Normal file
View File

@@ -0,0 +1,241 @@
import json
from pathlib import Path
from typing import List, Optional, Sequence, Union
import mlx.core as mx
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
PROTEIN_TOKENS = [
"L",
"A",
"G",
"V",
"S",
"E",
"R",
"T",
"I",
"D",
"P",
"K",
"Q",
"N",
"F",
"Y",
"M",
"H",
"W",
"C",
"X",
"B",
"U",
"Z",
"O",
".",
"-",
]
ArrayLike = Union[List[int], mx.array]
class ProteinTokenizer:
"""
Protein sequence tokenizer compatible with ESM-2.
This class converts protein sequences into token IDs and back, following
the vocabulary, special tokens, and formatting rules used by ESM-2.
"""
def __init__(
self,
vocab_file: Optional[str] = None,
special_tokens_map_file: Optional[str] = None,
):
"""
Initialize the ProteinTokenizer.
Args:
vocab_file: Optional path to a file containing the vocabulary,
one token per line.
special_tokens_map_file: Optional path to a JSON file defining
special token names and values.
If both files are provided, they override the default vocabulary and
special token mappings. Otherwise, defaults are loaded.
"""
# Load vocabulary from files if given, otherwise use built-in defaults
if vocab_file and special_tokens_map_file:
self._load_from_files(vocab_file, special_tokens_map_file)
else:
self._load_default_vocab()
# Create token ↔ ID mappings
self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}
# Cache special token IDs
self.cls_id = self.token_to_id["<cls>"]
self.pad_id = self.token_to_id["<pad>"]
self.eos_id = self.token_to_id["<eos>"]
self.unk_id = self.token_to_id["<unk>"]
self.mask_id = self.token_to_id["<mask>"]
# Behavior flags for ESM-2-style BOS/EOS
self.prepend_bos = True
self.append_eos = True
def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
"""Load vocabulary and special tokens from the provided files."""
# Vocabulary file: one token per line
vocab_path = Path(vocab_file)
with open(vocab_path, "r", encoding="utf-8") as f:
self.vocab = [line.strip() for line in f if line.strip()]
# Special tokens mapping file (JSON)
special_tokens_path = Path(special_tokens_map_file)
with open(special_tokens_path, "r", encoding="utf-8") as f:
self.special_tokens_map = json.load(f)
def _load_default_vocab(self) -> None:
"""Load the built-in ESM vocabulary and special token mapping."""
# ESM convention: prepend special tokens, then amino acids, then <mask>
prepend_toks = ["<cls>", "<pad>", "<eos>", "<unk>"]
append_toks = ["<mask>"]
self.vocab = prepend_toks + PROTEIN_TOKENS
# Pad vocab size to multiple of 8 (original implementation detail)
while len(self.vocab) % 8 != 0:
self.vocab.append(f"<null_{len(self.vocab) - len(prepend_toks)}>")
self.vocab.extend(append_toks)
# Default special tokens map
self.special_tokens_map = {
"cls_token": "<cls>",
"pad_token": "<pad>",
"eos_token": "<eos>",
"unk_token": "<unk>",
"mask_token": "<mask>",
}
def encode(
self,
sequence: str,
*,
add_special_tokens: bool = True,
return_batch_dim: bool = False,
dtype=mx.int32,
) -> mx.array:
"""
Convert a protein sequence into token IDs.
Args:
sequence: Protein sequence (case-insensitive).
add_special_tokens: If True, add <cls> at the start and <eos> at the end.
return_batch_dim: If True, output shape will be (1, L) instead of (L,).
dtype: MLX dtype for the returned array.
Returns:
mx.array: Token IDs of shape (L,) or (1, L).
"""
ids: List[int] = []
if add_special_tokens and self.prepend_bos:
ids.append(self.cls_id)
# Map each residue to its ID (defaulting to <unk> if not in vocab)
for ch in sequence.upper():
ids.append(self.token_to_id.get(ch, self.unk_id))
if add_special_tokens and self.append_eos:
ids.append(self.eos_id)
arr = mx.array(ids, dtype=dtype)
return mx.expand_dims(arr, axis=0) if return_batch_dim else arr
def batch_encode(
self,
sequences: Sequence[str],
*,
add_special_tokens: bool = True,
max_length: Optional[int] = None,
dtype=mx.int32,
) -> mx.array:
"""
Encode multiple protein sequences into a padded batch.
Args:
sequences: List/sequence of protein strings.
add_special_tokens: If True, add <cls> and <eos> tokens.
max_length: If provided, truncate sequences to this length before padding.
dtype: MLX dtype for the returned array.
Returns:
mx.array: Tensor of shape (B, L) with right-padding using <pad> IDs.
"""
# Encode each sequence as (L,)
encoded = [
self.encode(s, add_special_tokens=add_special_tokens, dtype=dtype)
for s in sequences
]
encoded = [e if e.ndim == 1 else e[0] for e in encoded]
if max_length is not None:
encoded = [e[:max_length] for e in encoded]
# Find the longest sequence and right-pad all others
max_len = max((int(e.shape[0]) for e in encoded), default=0)
padded = []
for e in encoded:
pad_len = max_len - int(e.shape[0])
if pad_len > 0:
pad = mx.full((pad_len,), self.pad_id, dtype=dtype)
e = mx.concatenate([e, pad], axis=0)
padded.append(e)
return mx.stack(padded, axis=0) if padded else mx.array([], dtype=dtype)
def decode(
self,
token_ids: ArrayLike,
*,
skip_special_tokens: bool = False,
) -> str:
"""
Convert token IDs back into a protein sequence string.
Args:
token_ids: 1-D or 2-D array/list of IDs. If 2-D, only the first row is decoded.
skip_special_tokens: If True, remove all special tokens from output.
Returns:
str: Protein sequence.
"""
# Normalize to a 1-D MLX array
if hasattr(token_ids, "shape") and hasattr(token_ids, "tolist"):
ids = token_ids if token_ids.ndim == 1 else token_ids[0]
else:
ids = mx.array(token_ids, dtype=mx.int32)
ids_list = [int(x) for x in ids.tolist()]
toks: List[str] = []
for i in ids_list:
tok = self.id_to_token.get(i, "<unk>")
if skip_special_tokens and tok in {
"<cls>",
"<pad>",
"<eos>",
"<unk>",
"<mask>",
}:
continue
toks.append(tok)
return "".join(toks)
def __len__(self) -> int:
"""Return the size of the tokenizers vocabulary."""
return len(self.vocab)

81
esm/main.py Normal file
View File

@@ -0,0 +1,81 @@
import argparse
import mlx.core as mx
from esm import ESM2
def main():
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
parser.add_argument(
"--model-path",
default="checkpoints/mlx-esm2_t33_650M_UR50D",
help="Path to MLX model checkpoint",
)
parser.add_argument(
"--sequence",
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
help="Protein sequence to test (default: human insulin)",
)
parser.add_argument(
"--mask-position",
type=int,
default=None,
help="Position to mask (default: middle of sequence)",
)
args = parser.parse_args()
# Load pretrained ESM-2 model and tokenizer
tokenizer, model = ESM2.from_pretrained(args.model_path)
# Determine sequence and position to mask
sequence = args.sequence.upper()
mask_pos = (
args.mask_position if args.mask_position is not None else len(sequence) // 2
)
if mask_pos >= len(sequence):
mask_pos = len(sequence) - 1
original_aa = sequence[mask_pos] # The original amino acid at masked position
# Display input info
print(f"Original sequence: {sequence}")
print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
print(f"Predicting position {mask_pos}: {original_aa}\n")
# Tokenize sequence before and after the mask
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
# Build token sequence with <cls>, <mask>, and <eos>
tokens = mx.array(
[
[tokenizer.cls_id]
+ before.tolist()
+ [tokenizer.mask_id]
+ after.tolist()
+ [tokenizer.eos_id]
]
)
mask_token_pos = 1 + len(before) # Position of <mask> token
# Run model to get logits for each token position
logits = model(tokens)["logits"]
probs = mx.softmax(
logits[0, mask_token_pos, :]
) # Softmax over vocabulary at mask position
# Get top-5 most likely tokens
top_indices = mx.argsort(probs)[-5:][::-1]
# Print predictions
print("Top predictions:")
for i, idx in enumerate(top_indices):
token = tokenizer.vocab[int(idx)]
if token in tokenizer.vocab:
prob = float(probs[idx])
marker = "" if token == original_aa else " "
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,602 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3fbacbe4",
"metadata": {},
"source": [
"## Predicting Protein Contacts with ESM-2\n",
"\n",
"Understanding how amino acids interact within a folded protein is essential for grasping protein function and stability. Contact prediction, the task of identifying which residues are close together in three-dimensional space, is a key step in the sequence to structure process. ESM-2s learned attention patterns capture evolutionary signals that encode structural information, which allows the model to predict residue contacts directly from sequence data.\n",
"\n",
"In this notebook, we'll explore ESM-2's ability to predict protein contacts across three diverse proteins from different organisms:\n",
"\n",
"**Bacterial Transport:**\n",
"- **1a3a (PTS Mannitol Component)**: A phosphoenolpyruvate-dependent sugar phosphotransferase system component from *E. coli*, essential for carbohydrate metabolism\n",
"\n",
"**Stress Response:**\n",
"- **5ahw (Universal Stress Protein)**: A conserved stress response protein from *Mycolicibacterium smegmatis* that helps cells survive harsh conditions\n",
"\n",
"**Human Metabolism:**\n",
"- **1xcr (Ester Hydrolase)**: A human enzyme (C11orf54) involved in lipid metabolism and cellular signaling pathways\n",
"\n",
"We will evaluate how effectively ESM-2 captures the structural relationships present in these sequences, measuring precision across different sequence separation ranges to assess both local and long-range contact prediction performance. This notebook is a modified version of a [notebook by the same name](https://github.com/facebookresearch/esm/blob/main/examples/contact_prediction.ipynb) from the [offical ESM repsitory](https://github.com/facebookresearch/esm)."
]
},
{
"cell_type": "markdown",
"id": "08352b12",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Here we import all neccessary libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1047c94",
"metadata": {},
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mRunning cells with '.venv (Python 3.11.13)' requires the ipykernel package.\n",
"\u001b[1;31mInstall 'ipykernel' into the Python environment. \n",
"\u001b[1;31mCommand: '/Users/vincent/Developer/mlx-examples/.venv/bin/python -m pip install ipykernel -U --force-reinstall'"
]
}
],
"source": [
"from typing import List, Tuple, Optional, Dict\n",
"import string\n",
"\n",
"import mlx.core as mx\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from scipy.spatial.distance import squareform, pdist\n",
"import biotite.structure as bs\n",
"from biotite.database import rcsb\n",
"from biotite.structure.io.pdbx import CIFFile, get_structure\n",
"from Bio import SeqIO"
]
},
{
"cell_type": "markdown",
"id": "5f0af076",
"metadata": {},
"source": [
"Download multiple sequence alignment (MSA) files for our three test proteins from the ESM repository."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3264b66d",
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p data\n",
"!curl -o data/1a3a_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1a3a_1_A.a3m\n",
"!curl -o data/5ahw_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/5ahw_1_A.a3m\n",
"!curl -o data/1xcr_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1xcr_1_A.a3m"
]
},
{
"cell_type": "markdown",
"id": "cbf1d0cb",
"metadata": {},
"source": [
"### Loading the model\n",
"\n",
"Load the ESM-2 model. Here we will use the 650M parameter version. Change the path below to point to your converted checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4406e8a0",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"..\")\n",
"\n",
"from esm import ESM2\n",
"\n",
"esm_checkpoint = \"../checkpoints/mlx-esm2_t33_650M_UR50D\"\n",
"tokenizer, model = ESM2.from_pretrained(esm_checkpoint)"
]
},
{
"cell_type": "markdown",
"id": "77596456",
"metadata": {},
"source": [
"### Defining functions"
]
},
{
"cell_type": "markdown",
"id": "eb5f07ed",
"metadata": {},
"source": [
"#### Parsing alignments"
]
},
{
"cell_type": "markdown",
"id": "e754abd7",
"metadata": {},
"source": [
"This function parses multiple sequence alignment files and clean up insertion artifacts. MSA files often contain lowercase letters and special characters (`.`, `*`) to indicate insertions relative to the reference sequence - these need to be removed to get the core aligned sequences."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43717bea",
"metadata": {},
"outputs": [],
"source": [
"deletekeys = dict.fromkeys(string.ascii_lowercase)\n",
"deletekeys[\".\"] = None\n",
"deletekeys[\"*\"] = None\n",
"translation = str.maketrans(deletekeys)\n",
"\n",
"def read_sequence(filename: str) -> Tuple[str, str]:\n",
" \"\"\" Reads the first (reference) sequences from a fasta or MSA file.\"\"\"\n",
" record = next(SeqIO.parse(filename, \"fasta\"))\n",
" return record.description, str(record.seq)\n",
"\n",
"def remove_insertions(sequence: str) -> str:\n",
" \"\"\" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. \"\"\"\n",
" return sequence.translate(translation)\n",
"\n",
"def read_msa(filename: str) -> List[Tuple[str, str]]:\n",
" \"\"\" Reads the sequences from an MSA file, automatically removes insertions.\"\"\"\n",
" return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, \"fasta\")]"
]
},
{
"cell_type": "markdown",
"id": "628d7de1",
"metadata": {},
"source": [
"#### Converting structures to contacts\n",
"\n",
"There are many ways to define a protein contact. Here we're using the definition of 8 angstroms between carbon beta atoms. Note that the position of the carbon beta is imputed from the position of the N, CA, and C atoms for each residue."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21e0b44b",
"metadata": {},
"outputs": [],
"source": [
"def extend(a, b, c, L, A, D):\n",
" \"\"\"\n",
" input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral\n",
" output: 4th coord\n",
" \"\"\"\n",
" def normalize(x):\n",
" return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)\n",
"\n",
" bc = normalize(b - c)\n",
" n = normalize(np.cross(b - a, bc))\n",
" m = [bc, np.cross(n, bc), n]\n",
" d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]\n",
" return c + sum([m * d for m, d in zip(m, d)])\n",
"\n",
"def contacts_from_pdb(\n",
" structure: bs.AtomArray,\n",
" distance_threshold: float = 8.0,\n",
" chain: Optional[str] = None,\n",
") -> np.ndarray:\n",
" \"\"\"Extract contacts from PDB structure.\"\"\"\n",
" mask = ~structure.hetero\n",
" if chain is not None:\n",
" mask &= structure.chain_id == chain\n",
"\n",
" N = structure.coord[mask & (structure.atom_name == \"N\")]\n",
" CA = structure.coord[mask & (structure.atom_name == \"CA\")]\n",
" C = structure.coord[mask & (structure.atom_name == \"C\")]\n",
"\n",
" Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)\n",
" dist = squareform(pdist(Cbeta))\n",
" \n",
" contacts = dist < distance_threshold\n",
" contacts = contacts.astype(np.int64)\n",
" contacts[np.isnan(dist)] = -1\n",
" return contacts"
]
},
{
"cell_type": "markdown",
"id": "5473f306",
"metadata": {},
"source": [
"#### Computing contact precisions"
]
},
{
"cell_type": "markdown",
"id": "e361a9f3",
"metadata": {},
"source": [
"Calculate precision metrics to evaluate contact prediction quality. The `compute_precisions` function ranks predicted contacts by confidence and measures how many of the top predictions are true contacts, while `evaluate_prediction` breaks this down by sequence separation ranges (local, short, medium, long-range) since predicting distant contacts is typically much harder than nearby ones."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62c37bbd",
"metadata": {},
"outputs": [],
"source": [
"def compute_precisions(\n",
" predictions: mx.array,\n",
" targets: mx.array,\n",
" minsep: int = 6,\n",
" maxsep: Optional[int] = None,\n",
" override_length: Optional[int] = None,\n",
") -> Dict[str, mx.array]:\n",
" \"\"\"Compute precision metrics for contact prediction.\"\"\"\n",
" batch_size, seqlen, _ = predictions.shape\n",
" \n",
" if maxsep is not None:\n",
" sep_mask_2d = mx.abs(mx.arange(seqlen)[None, :] - mx.arange(seqlen)[:, None]) <= maxsep\n",
" targets = targets * sep_mask_2d[None, :]\n",
" \n",
" targets = targets.astype(mx.float32)\n",
" src_lengths = (targets >= 0).sum(axis=-1).sum(axis=-1).astype(mx.float32)\n",
" \n",
" x_ind, y_ind = [], []\n",
" for i in range(seqlen):\n",
" for j in range(i + minsep, seqlen):\n",
" x_ind.append(i)\n",
" y_ind.append(j)\n",
" \n",
" x_ind = mx.array(x_ind)\n",
" y_ind = mx.array(y_ind)\n",
" \n",
" predictions_upper = predictions[:, x_ind, y_ind]\n",
" targets_upper = targets[:, x_ind, y_ind]\n",
"\n",
" topk = seqlen if override_length is None else max(seqlen, override_length)\n",
" indices = mx.argsort(predictions_upper, axis=-1)[:, ::-1][:, :topk]\n",
" \n",
" batch_indices = mx.arange(batch_size)[:, None]\n",
" topk_targets = targets_upper[batch_indices, indices]\n",
" \n",
" if topk_targets.shape[1] < topk:\n",
" pad_shape = (topk_targets.shape[0], topk - topk_targets.shape[1])\n",
" padding = mx.zeros(pad_shape)\n",
" topk_targets = mx.concatenate([topk_targets, padding], 1)\n",
"\n",
" cumulative_dist = mx.cumsum(topk_targets, -1)\n",
"\n",
" gather_lengths = src_lengths[:, None]\n",
" if override_length is not None:\n",
" gather_lengths = override_length * mx.ones_like(gather_lengths)\n",
"\n",
" precision_fractions = mx.arange(0.1, 1.1, 0.1)\n",
" gather_indices = (precision_fractions[None, :] * gather_lengths) - 1\n",
" gather_indices = mx.clip(gather_indices, 0, cumulative_dist.shape[1] - 1)\n",
" gather_indices = gather_indices.astype(mx.int32)\n",
"\n",
" binned_cumulative_dist = cumulative_dist[batch_indices, gather_indices]\n",
" binned_precisions = binned_cumulative_dist / (gather_indices + 1)\n",
"\n",
" pl5 = binned_precisions[:, 1]\n",
" pl2 = binned_precisions[:, 4]\n",
" pl = binned_precisions[:, 9]\n",
" auc = binned_precisions.mean(-1)\n",
"\n",
" return {\"AUC\": auc, \"P@L\": pl, \"P@L2\": pl2, \"P@L5\": pl5}\n",
"\n",
"def evaluate_prediction(\n",
" predictions: mx.array,\n",
" targets: mx.array,\n",
") -> Dict[str, float]:\n",
" \"\"\"Evaluate contact predictions across different sequence separation ranges.\"\"\"\n",
" contact_ranges = [\n",
" (\"local\", 3, 6),\n",
" (\"short\", 6, 12),\n",
" (\"medium\", 12, 24),\n",
" (\"long\", 24, None),\n",
" ]\n",
" metrics = {}\n",
" \n",
" for name, minsep, maxsep in contact_ranges:\n",
" rangemetrics = compute_precisions(\n",
" predictions,\n",
" targets,\n",
" minsep=minsep,\n",
" maxsep=maxsep,\n",
" )\n",
" for key, val in rangemetrics.items():\n",
" metrics[f\"{name}_{key}\"] = float(val[0])\n",
" return metrics"
]
},
{
"cell_type": "markdown",
"id": "5873e052",
"metadata": {},
"source": [
"#### Predicting contacts"
]
},
{
"cell_type": "markdown",
"id": "2d5778a9",
"metadata": {},
"source": [
"This function wraps the tokenization and model inference steps, converting a raw amino acid sequence into token IDs and passing them through ESM-2's contact prediction head to produce a contact probability matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dddf31a7",
"metadata": {},
"outputs": [],
"source": [
"def predict_contacts(sequence: str, model, tokenizer) -> mx.array:\n",
" \"\"\" Predict contacts for a given sequence \"\"\"\n",
" tokens = tokenizer.encode(sequence)\n",
" contacts = model.predict_contacts(tokens)\n",
" return contacts"
]
},
{
"cell_type": "markdown",
"id": "62562401",
"metadata": {},
"source": [
"#### Plotting results\n",
"\n",
"This function visualizes contacts as a symmetric matrix where both axes index residue positions. The lower triangle shows the models confidence as a blue heatmap, with darker cells indicating higher confidence. The upper triangle overlays evaluation markers: blue dots are correctly predicted contacts (true positives), red dots are predicted but not real (false positives), and grey dots are real contacts the model missed (false negatives)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03e03791",
"metadata": {},
"outputs": [],
"source": [
"def plot_contacts_and_predictions(\n",
" predictions: mx.array,\n",
" contacts: np.ndarray,\n",
" ax,\n",
" title: str,\n",
" cmap: str = \"Blues\",\n",
" ms: float = 1,\n",
"):\n",
" \"\"\"Plot contact predictions and true contacts.\"\"\"\n",
" if isinstance(predictions, mx.array):\n",
" predictions = np.array(predictions)\n",
" \n",
" seqlen = contacts.shape[0]\n",
" relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))\n",
" bottom_mask = relative_distance < 0\n",
" masked_image = np.ma.masked_where(bottom_mask, predictions)\n",
" invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6\n",
" predictions_copy = predictions.copy()\n",
" predictions_copy[invalid_mask] = float(\"-inf\")\n",
"\n",
" topl_val = np.sort(predictions_copy.reshape(-1))[-seqlen]\n",
" pred_contacts = predictions_copy >= topl_val\n",
" true_positives = contacts & pred_contacts & ~bottom_mask\n",
" false_positives = ~contacts & pred_contacts & ~bottom_mask\n",
" other_contacts = contacts & ~pred_contacts & ~bottom_mask\n",
"\n",
" ax.imshow(masked_image, cmap=cmap)\n",
" ax.plot(*np.where(other_contacts), \"o\", c=\"grey\", ms=ms)\n",
" ax.plot(*np.where(false_positives), \"o\", c=\"r\", ms=ms)\n",
" ax.plot(*np.where(true_positives), \"o\", c=\"b\", ms=ms)\n",
" ax.set_title(title)\n",
" ax.axis(\"square\")\n",
" ax.set_xlim([0, seqlen])\n",
" ax.set_ylim([0, seqlen])"
]
},
{
"cell_type": "markdown",
"id": "9364c984",
"metadata": {},
"source": [
"### Predict and visualize\n",
"Here we'll use ESM-2 contact prediction on our three test proteins and evaluate the results. We'll compute precision metrics across different sequence separation ranges and create contact maps that visualize both the model's predictions and how well they match the true protein structures."
]
},
{
"cell_type": "markdown",
"id": "9fa9e59e",
"metadata": {},
"source": [
"#### Read Data"
]
},
{
"cell_type": "markdown",
"id": "7da50dc2",
"metadata": {},
"source": [
"Load experimental protein structures from the Protein Data Bank and extract true contact maps for evaluation, while also parsing the reference sequences from our MSA files that will serve as input to ESM-2."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d276137",
"metadata": {},
"outputs": [],
"source": [
"PDB_IDS = [\"1a3a\", \"5ahw\", \"1xcr\"]\n",
"\n",
"structures = {\n",
" name.lower(): get_structure(CIFFile.read(rcsb.fetch(name, \"cif\")))[0]\n",
" for name in PDB_IDS\n",
"}\n",
"\n",
"contacts = {\n",
" name: contacts_from_pdb(structure, chain=\"A\") \n",
" for name, structure in structures.items()\n",
"}\n",
"\n",
"msas = {\n",
" name: read_msa(f\"data/{name.lower()}_1_A.a3m\")\n",
" for name in PDB_IDS\n",
"}\n",
"\n",
"sequences = {\n",
" name: msa[0] for name, msa in msas.items()\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "4ce64f18",
"metadata": {},
"source": [
"#### ESM-2 predictions"
]
},
{
"cell_type": "markdown",
"id": "1f2da88f",
"metadata": {},
"source": [
"##### Evaluate predictions"
]
},
{
"cell_type": "markdown",
"id": "0adb0a11",
"metadata": {},
"source": [
"This loop generates contact predictions for each protein using ESM-2, compares them against the experimentally determined structures, and computes precision metrics across different sequence separation ranges to evaluate model performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "941b4afa",
"metadata": {},
"outputs": [],
"source": [
"predictions = {}\n",
"results = []\n",
"\n",
"for pdb_id in sequences:\n",
" _, sequence = sequences[pdb_id]\n",
" prediction = predict_contacts(sequence, model, tokenizer)\n",
" predictions[pdb_id] = prediction[0]\n",
" \n",
" true_contacts = mx.array(contacts[pdb_id])\n",
" \n",
" min_len = min(prediction.shape[1], true_contacts.shape[0])\n",
" pred_trimmed = prediction[:, :min_len, :min_len]\n",
" true_trimmed = true_contacts[:min_len, :min_len]\n",
" true_trimmed = mx.expand_dims(true_trimmed, axis=0)\n",
" \n",
" metrics = evaluate_prediction(pred_trimmed, true_trimmed)\n",
" result = {\"id\": pdb_id, \"model\": \"ESM-2 (Unsupervised)\"}\n",
" result.update(metrics)\n",
" results.append(result)\n",
"\n",
"results_df = pd.DataFrame(results)\n",
"display(results_df)"
]
},
{
"cell_type": "markdown",
"id": "c5c7418a",
"metadata": {},
"source": [
"The results demonstrate that ESM-2 excels at predicting long-range contacts, with precision scores ranging from 40.9% to 86.4% for residues more than 24 positions apart. Performance is consistently higher for distant contacts compared to local ones. For example, the universal stress protein (5ahw) achieves 86.4% precision for long-range contacts but only 2.4% for local contacts between 3 and 6 residues apart. This trend is observed across all three proteins, with medium-range contacts (1224 residues apart) and short-range contacts (612 residues apart) showing intermediate accuracy. These results suggest that ESM-2 has learned to identify evolutionarily conserved structural motifs that connect distant regions of the sequence, which are often critical for protein fold stability and function."
]
},
{
"cell_type": "markdown",
"id": "487cff51",
"metadata": {},
"source": [
"##### Plot contacts and predictions"
]
},
{
"cell_type": "markdown",
"id": "10291191",
"metadata": {},
"source": [
"This analysis generates contact map visualizations for all three proteins, presenting ESM-2s predictions as heatmaps and overlaying the true experimental contacts as colored dots."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "628efc10",
"metadata": {},
"outputs": [],
"source": [
"proteins = [r['id'] for r in results]\n",
"fig, axes = plt.subplots(figsize=(6 * len(proteins), 6), ncols=len(proteins))\n",
"if len(proteins) == 1:\n",
" axes = [axes]\n",
"\n",
"for ax, pdb_id in zip(axes, proteins):\n",
" prediction = predictions[pdb_id]\n",
" target = contacts[pdb_id]\n",
" \n",
" result = next(r for r in results if r['id'] == pdb_id)\n",
" long_pl = result['long_P@L']\n",
" \n",
" plot_contacts_and_predictions(\n",
" prediction, target, ax=ax, \n",
" title=f\"{pdb_id}: Long Range P@L: {100 * long_pl:.1f}%\"\n",
" )\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "99e1edaf",
"metadata": {},
"source": [
"The contact maps highlight ESM-2s strong ability to detect long-range structural relationships. In each panel, the lower triangle shows model predictions, where darker blue regions indicate high-confidence contacts, and the upper triangle shows the corresponding experimental data. Correct predictions appear as blue dots, forming distinct off-diagonal patterns in 5ahw and 1a3a that capture key global fold interactions. Red dots mark false positives, which are relatively rare, while grey dots represent missed contacts. These missed contacts are notably more frequent in 1xcr, consistent with its lower long-range precision. The dense clusters of blue true positives in 5ahw, compared to the sparser, fragmented patterns in 1xcr, clearly illustrate the variation in predictive performance across proteins."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

12
esm/requirements.txt Normal file
View File

@@ -0,0 +1,12 @@
mlx
torch
transformers
numpy
pandas
seaborn
biopython
biotite
scipy
tqdm
scikit-learn
matplotlib

121
esm/test.py Normal file
View File

@@ -0,0 +1,121 @@
import unittest
import numpy as np
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM
from esm import ESM2
# Paths for MLX and Hugging Face versions of ESM-2
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
HF_PATH = "facebook/esm2_t12_35M_UR50D"
def load_mlx_model():
"""Load MLX ESM-2 model and tokenizer."""
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
return tokenizer, model
def load_hf_model():
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
config = EsmConfig.from_pretrained(
HF_PATH, output_hidden_states=True, output_attentions=True
)
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
return tokenizer, model
class TestESM2(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Load both MLX and HF models/tokenizers once for all tests
cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
cls.hf_tokenizer, cls.hf_model = load_hf_model()
def test_tokenizer(self):
"""Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
# Compare batched tokenization (padded sequences)
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
hf_batch = (
self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
.cpu()
.numpy()
)
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
self.assertTrue(
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
)
# Compare single-sequence encode/decode
for sequence in sequences:
mlx_tokens = self.mlx_tokenizer.encode(sequence)
hf_tokens = (
self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
.cpu()
.numpy()
.tolist()[0]
)
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens),
self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
)
self.assertEqual(
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
" ", ""
),
)
def test_model(self):
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
sequences = [
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
]
for sequence in sequences:
# Tokenize
mlx_tokens = self.mlx_tokenizer.encode(sequence, return_batch_dim=True)
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
# Forward pass
mlx_outputs = self.mlx_model(
mlx_tokens,
repr_layers=[self.mlx_model.num_layers],
need_head_weights=True,
)
hf_outputs = self.hf_model(input_ids=hf_tokens)
# Compare logits
mlx_logits = np.array(mlx_outputs["logits"])
hf_logits = hf_outputs["logits"].detach().cpu().numpy()
self.assertTrue(np.allclose(mlx_logits, hf_logits, rtol=1e-4, atol=1e-4))
# Compare final-layer hidden states
final_layer = self.mlx_model.num_layers
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
self.assertTrue(
np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
)
# Compare attentions for final layer
mlx_attentions = np.array(
mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
)
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
self.assertTrue(
np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
)
if __name__ == "__main__":
unittest.main()

View File

@@ -167,8 +167,9 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6
```
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
Or you can directly use the pre-processed Hugging Face dataset
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
for fine-tuning.
```shell
python dreambooth.py \
@@ -210,3 +211,71 @@ speed during generation.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .
Distributed Computation
------------------------
The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set-up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.
### Distributed Finetuning
Distributed finetuning scales very well with FLUX and all one has to do is
adjust the gradient accumulation and training iterations so that the batch
size remains the same. For instance, to replicate the following training
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
On 4 machines we simply run
```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 150 --iterations 300 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 2 \
mlx-community/dreambooth-dog6
```
Note the iterations that changed to 300 from 1200 and the gradient accumulations to 2 from 8.
### Distributed Inference
Distributed inference can be divided in two different approaches. The first
approach is the data-parallel approach, where each node generates its own
images and shares them at the end. The second approach is the model-parallel
approach where the model is shared across the nodes and they collaboratively
generate the images.
The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.
For better performance in the model-parallel approach we suggest that you use a
[thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
All you have to do once again is use `mlx.launch` as follows
```shell
mlx.launch --verbose --hostfile hostfile.json -- \
python txt2image.py --model schnell \
--n-images 8 \
--image-size 512x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars'
```
for model-parallel generation you may want to also pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch` which is an experimental setting that
reduces the CPU/GPU synchronization overhead.

View File

@@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module):
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.sharding_group = None
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
@@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module):
attn = _attention(q, k, v, pe)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# Project - cat - average - split
txt_attn = self.txt_attn.proj(txt_attn)
img_attn = self.img_attn.proj(img_attn)
if self.sharding_group is not None:
attn = mx.concatenate([txt_attn, img_attn], axis=1)
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
img = img + img_mod1.gate * img_attn
img_mlp = self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
txt = txt + txt_mod1.gate * txt_attn
txt_mlp = self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
if self.sharding_group is not None:
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
# finalize the img/txt blocks
img = img + img_mod2.gate * img_mlp
txt = txt + txt_mod2.gate * txt_mlp
return img, txt

View File

@@ -5,6 +5,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from .layers import (
DoubleStreamBlock,
@@ -96,6 +97,47 @@ class Flux(nn.Module):
new_weights[k] = w
return new_weights
def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()
N = group.size()
if N == 1:
return
for block in self.double_blocks:
block.num_heads //= N
block.img_attn.num_heads //= N
block.txt_attn.num_heads //= N
block.sharding_group = group
block.img_attn.qkv = shard_linear(
block.img_attn.qkv, "all-to-sharded", segments=3, group=group
)
block.txt_attn.qkv = shard_linear(
block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
)
shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
block.img_mlp.layers[0] = shard_linear(
block.img_mlp.layers[0], "all-to-sharded", group=group
)
block.txt_mlp.layers[0] = shard_linear(
block.txt_mlp.layers[0], "all-to-sharded", group=group
)
shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
for block in self.single_blocks:
block.num_heads //= N
block.hidden_size //= N
block.linear1 = shard_linear(
block.linear1,
"all-to-sharded",
segments=[1 / 7, 2 / 7, 3 / 7],
group=group,
)
block.linear2 = shard_linear(
block.linear2, "sharded-to-all", segments=[1 / 5], group=group
)
def __call__(
self,
img: mx.array,

View File

@@ -0,0 +1,109 @@
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def print_zero(group, *args, **kwargs):
if group.rank() == 0:
flush = kwargs.pop("flush", True)
print(*args, **kwargs, flush=flush)
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
def to_latent_size(image_size):
h, w = image_size
h = ((h + 15) // 16) * 16
w = ((w + 15) // 16) * 16
if (h, w) != image_size:
print(
"Warning: The image dimensions need to be divisible by 16px. "
f"Changing size to {h}x{w}."
)
return (h // 8, w // 8)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using FLUX"
)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--output", default="out.png")
args = parser.parse_args()
flux = FluxPipeline("flux-" + args.model, t5_padding=True)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
group = mx.distributed.init()
if group.size() > 1:
flux.flow.shard(group)
print_zero(group, "Loading models")
flux.ensure_models_are_loaded()
def print_help():
print_zero(group, "The command list:")
print_zero(group, "- 'q' to exit")
print_zero(group, "- 's HxW' to change the size of the image")
print_zero(group, "- 'n S' to change the number of steps")
print_zero(group, "- 'h' to print this help")
print_zero(group, "FLUX interactive session")
print_help()
seed = 0
size = (512, 512)
latent_size = to_latent_size(size)
steps = 50 if args.model == "dev" else 4
while True:
prompt = input(">> " if group.rank() == 0 else "")
if prompt == "q":
break
if prompt == "h":
print_help()
continue
if prompt.startswith("s "):
size = tuple([int(xi) for xi in prompt[2:].split("x")])
print_zero(group, "Setting the size to", size)
latent_size = to_latent_size(size)
continue
if prompt.startswith("n "):
steps = int(prompt[2:])
print_zero(group, "Setting the steps to", steps)
continue
seed += 1
latents = flux.generate_latents(
prompt,
n_images=1,
num_steps=steps,
latent_size=latent_size,
guidance=4.0,
seed=seed,
)
print_zero(group, "Processing prompt")
mx.eval(next(latents))
print_zero(group, "Generating latents")
for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
mx.eval(xt)
print_zero(group, "Generating image")
xt = flux.decode(xt, latent_size)
xt = (xt * 255).astype(mx.uint8)
mx.eval(xt)
im = Image.fromarray(np.array(xt[0]))
im.save(args.output)
print_zero(group, "Saved at", args.output, end="\n\n")

View File

@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
description="Generate images from a textual prompt using FLUX"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
@@ -62,6 +62,7 @@ if __name__ == "__main__":
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
parser.add_argument("--force-shard", action="store_true")
args = parser.parse_args()
# Load the models
@@ -76,6 +77,24 @@ if __name__ == "__main__":
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
# Figure out what kind of distributed generation we should do
group = mx.distributed.init()
n_images = args.n_images
should_gather = False
if group.size() > 1:
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
flux.flow.shard(group)
else:
n_images //= group.size()
should_gather = True
# If we are sharding we should have the same seed and if we are doing
# data parallel generation we should have different seeds
if args.seed is None:
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
if should_gather:
args.seed = args.seed + group.rank()
if args.preload_models:
flux.ensure_models_are_loaded()
@@ -83,7 +102,7 @@ if __name__ == "__main__":
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=args.n_images,
n_images=n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
@@ -93,8 +112,8 @@ if __name__ == "__main__":
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
@@ -102,36 +121,42 @@ if __name__ == "__main__":
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps):
for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
peak_mem_generation = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_decoding = mx.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
# Gather them if each node has different images
decoded = mx.concatenate(decoded, axis=0)
if should_gather:
decoded = mx.distributed.all_gather(decoded)
mx.eval(decoded)
if args.save_raw:
*name, suffix = args.output.split(".")
name = ".".join(name)
x = mx.concatenate(decoded, axis=0)
x = decoded
x = (x * 255).astype(mx.uint8)
for i in range(len(x)):
im = Image.fromarray(np.array(x[i]))
im.save(".".join([name, str(i), suffix]))
else:
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = decoded
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@@ -143,7 +168,7 @@ if __name__ == "__main__":
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
if args.verbose and group.rank() == 0:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")

View File

@@ -1,47 +0,0 @@
# Contributing to MLX LM
Below are some tips to port LLMs available on Hugging Face to MLX.
Before starting checkout the [general contribution
guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md).
Next, from this directory, do an editable install:
```shell
pip install -e .
```
Then check if the model has weights in the
[safetensors](https://huggingface.co/docs/safetensors/index) format. If not
[follow instructions](https://huggingface.co/spaces/safetensors/convert) to
convert it.
After that, add the model file to the
[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models)
directory. You can see other examples there. We recommend starting from a model
that is similar to the model you are porting.
Make sure the name of the new model file is the same as the `model_type` in the
`config.json`, for example
[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17).
To determine the model layer names, we suggest either:
- Refer to the Transformers implementation if you are familiar with the
codebase.
- Load the model weights and check the weight names which will tell you about
the model structure.
- Look at the names of the weights by inspecting `model.safetensors.index.json`
in the Hugging Face repo.
To add LoRA support edit
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
Finally, add a test for the new modle type to the [model
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
From the `llms/` directory, you can run the tests with:
```shell
python -m unittest discover tests/
```

View File

@@ -1,2 +0,0 @@
include mlx_lm/requirements.txt
recursive-include mlx_lm/ *.py

View File

@@ -1,282 +1,6 @@
## Generate Text with LLMs and MLX
# MOVE NOTICE
The easiest way to get started is to install the `mlx-lm` package:
The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).
**With `pip`**:
```sh
pip install mlx-lm
```
**With `conda`**:
```sh
conda install -c conda-forge mlx-lm
```
The `mlx-lm` package also has:
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
### Quick Start
To generate text with an LLM use:
```bash
mlx_lm.generate --prompt "Hi!"
```
To chat with an LLM use:
```bash
mlx_lm.chat
```
This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.
Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:
```bash
mlx_lm.generate -h
```
### Python API
You can use `mlx-lm` as a module:
```python
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, add_generation_prompt=True
)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
```
>>> help(generate)
```
Check out the [generation
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
You can convert models using the Python API:
```python
from mlx_lm import convert
repo = "mistralai/Mistral-7B-Instruct-v0.3"
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
convert(repo, quantize=True, upload_repo=upload_repo)
```
This will generate a 4-bit quantized Mistral 7B and upload it to the repo
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
converted model in the path `mlx_model` by default.
To see a description of all the arguments you can do:
```
>>> help(convert)
```
#### Streaming
For streaming generation, use the `stream_generate` function. This yields
a generation response object.
For example,
```python
from mlx_lm import load, stream_generate
repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
model, tokenizer = load(repo)
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(response.text, end="", flush=True)
print()
```
### Command Line
You can also use `mlx-lm` from the command line with:
```
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
```
This will download a Mistral 7B model from the Hugging Face Hub and generate
text using the given prompt.
For a full list of options run:
```
mlx_lm.generate --help
```
To quantize a model from the command line run:
```
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
```
For more options run:
```
mlx_lm.convert --help
```
You can upload new models to Hugging Face by specifying `--upload-repo` to
`convert`. For example, to upload a quantized Mistral-7B model to the
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
```
mlx_lm.convert \
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
-q \
--upload-repo mlx-community/my-4bit-mistral
```
Models can also be converted and quantized directly in the
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
Face Space.
### Long Prompts and Generations
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
To use the rotating key-value cache pass the argument `--max-kv-size n` where
`n` can be any integer. Smaller values like `512` will use very little RAM but
result in worse quality. Larger values like `4096` or higher will use more RAM
but have better quality.
Caching prompts can substantially speedup reusing the same long context with
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
```bash
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--prompt-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
--prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
The cached prompt is treated as a prefix to the supplied prompt. Also notice
when using a cached prompt, the model to use is read from the cache and need
not be supplied explicitly.
Prompt caching can also be used in the Python API in order to to avoid
recomputing the prompt. This is useful in multi-turn dialogues or across
requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
### Supported Models
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
Here are a few examples of Hugging Face models that work with this example:
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
and
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
style models should work out of the box.
For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
enable the `trust_remote_code` option. You can do this by passing
`--trust-remote-code` in the command line. If you don't specify the flag
explicitly, you will be prompted to trust remote code in the terminal when
running the model.
For `Qwen` models you must also specify the `eos_token`. You can do this by
passing `--eos-token "<|endoftext|>"` in the command
line.
These options can also be set in the Python API. For example:
```python
model, tokenizer = load(
"qwen/Qwen-7B",
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
```
### Large Models
> [!NOTE]
This requires macOS 15.0 or higher to work.
Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
occupied by the model and cache. This requires macOS 15 or higher to
work.
If you see the following warning message:
> [WARNING] Generating with a model that requires ...
then the model will likely be slow on the given machine. If the model fits in
RAM then it can often be sped up by increasing the system wired memory limit.
To increase the limit, set the following `sysctl`:
```bash
sudo sysctl iogpu.wired_limit_mb=N
```
The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.
The package has been removed from the MLX Examples repo. Send new contributions
and issues to the MLX LM repo.

View File

@@ -40,7 +40,7 @@ def generate(
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt.size / prompt_time
prompt_tps = len(prompt) / prompt_time
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")

View File

@@ -19,10 +19,10 @@ class ModelArgs:
rms_norm_eps: float
vocab_size: int
context_length: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
model_type: Optional[str] = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
@@ -54,7 +54,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
self.repeats = n_heads // n_kv_heads
@@ -66,7 +66,7 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
1 / float(args.rope_scaling["factor"])
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
@@ -254,7 +254,7 @@ def translate_weight_names(name):
return name
def load(gguf_file: str, repo: str = None):
def load(gguf_file: str, repo: Optional[str] = None):
# If the gguf_file exists, try to load model from it.
# Otherwise try to download and cache from the HF repo
if not Path(gguf_file).exists():

View File

@@ -7,6 +7,7 @@ import glob
import json
import shutil
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard, shard_size = {}, 0
shard: Dict[str, mx.array] = {}
shard_size = 0
for k, v in weights.items():
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)

View File

@@ -23,7 +23,7 @@ class ModelArgs:
n_kv_heads: int
norm_eps: float
vocab_size: int
moe: dict = None
moe: dict
class Attention(nn.Module):
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)

View File

@@ -1,368 +0,0 @@
# Fine-Tuning with LoRA or QLoRA
You can use use the `mlx-lm` package to fine-tune an LLM with low rank
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Mistral
- Llama
- Phi2
- Mixtral
- Qwen2
- Gemma
- OLMo
- MiniCPM
- InternLM2
## Contents
- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
- [Data](#Data)
- [Memory Issues](#Memory-Issues)
## Run
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
```shell
mlx_lm.lora --help
```
Note, in the following the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
You can also specify a YAML config with `-c`/`--config`. For more on the format see the
[example YAML](examples/lora_config.yaml). For example:
```shell
mlx_lm.lora --config /path/to/config.yaml
```
If command-line flags are also used, they will override the corresponding
values in the config.
### Fine-tune
To fine-tune a model use:
```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--data <path_to_data> \
--iters 600
```
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).
For example, to fine-tune a Mistral 7B you can use `--model
mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
By default, the adapter config and learned weights are saved in `adapters/`.
You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
### Evaluate
To compute test set perplexity use:
```shell
mlx_lm.lora \
--model <path_to_model> \
--adapter-path <path_to_adapters> \
--data <path_to_data> \
--test
```
### Generate
For generation use `mlx_lm.generate`:
```shell
mlx_lm.generate \
--model <path_to_model> \
--adapter-path <path_to_adapters> \
--prompt "<your_model_prompt>"
```
## Fuse
You can generate a model fused with the low-rank adapters using the
`mlx_lm.fuse` command. This command also allows you to optionally:
- Upload the fused model to the Hugging Face Hub.
- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
Mixtral, and Llama style models in fp16 precision.
To see supported options run:
```shell
mlx_lm.fuse --help
```
To generate the fused model run:
```shell
mlx_lm.fuse --model <path_to_model>
```
This will by default load the adapters from `adapters/`, and save the fused
model in the path `fused_model/`. All of these are configurable.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```shell
mlx_lm.fuse \
--model mistralai/Mistral-7B-v0.1 \
--upload-repo mlx-community/my-lora-mistral-7b \
--hf-path mistralai/Mistral-7B-v0.1
```
To export a fused model to GGUF, run:
```shell
mlx_lm.fuse \
--model mistralai/Mistral-7B-v0.1 \
--export-gguf
```
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.
## Data
The LoRA command expects you to provide a dataset with `--data`. The MLX
Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
Face.
### Local Datasets
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
data formats. Here are examples of these formats:
`chat`:
```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
`tools`:
```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```
<details>
<summary>View the expanded single data tool format</summary>
```jsonl
{
"messages": [
{ "role": "user", "content": "What is the weather in San Francisco?" },
{
"role": "assistant",
"tool_calls": [
{
"id": "call_id",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
}
}
]
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and country, eg. San Francisco, USA"
},
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
},
"required": ["location", "format"]
}
}
}
]
}
```
The format for the `arguments` field in a function varies for different models.
Common formats include JSON strings and dictionaries. The example provided
follows the format used by
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
and [Mistral
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
A dictionary format is used in Hugging Face's [chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
Refer to the documentation for the model you are fine-tuning for more details.
</details>
`completions`:
```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:
```yaml
prompt_feature: "input"
completion_feature: "output"
```
Here, `"input"` is the expected key instead of the default `"prompt"`, and
`"output"` is the expected key instead of `"completion"`.
`text`:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys
in each line not expected by the loader will be ignored.
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example across multiple lines.
### Hugging Face Datasets
To use Hugging Face datasets, first install the `datasets` package:
```
pip install datasets
```
If the Hugging Face dataset is already in a supported format, you can specify
it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiwSQL data.
Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```yaml
hf_dataset:
name: "billsum"
prompt_feature: "text"
completion_feature: "summary"
```
- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
[chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
are used. This applies the model's chat template by default. If the model does
not have a chat template, then Hugging Face will use a default. For example,
the final text in the `chat` example above with Hugging Face's default template
becomes:
```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assistant you today.<|im_end|>
```
If you are unsure of the format to use, the `chat` or `completions` are good to
start with. For custom requirements on the format of the dataset, use the
`text` format to assemble the content yourself.
## Memory Issues
Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
more details.
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
setting this to `2` or `1` will reduce memory consumption. This may slow
things down a little, but will also reduce the memory use.
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
4. Longer examples require more memory. If it makes sense for your data, one thing
you can do is break your examples into smaller
sequences when making the `{train, valid, test}.jsonl` files.
5. Gradient checkpointing lets you trade-off memory use (less) for computation
(more) by recomputing instead of storing intermediate values needed by the
backward pass. You can use gradient checkpointing by passing the
`--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
larger batch sizes or sequence lengths with smaller or quantized models.
For example, for a machine with 32 GB the following should run reasonably fast:
```
mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--num-layers 4 \
--data wikisql
```
The above command on an M1 Max with 32 GB runs at about 250
tokens-per-second, using the MLX Example
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)

View File

@@ -1,22 +0,0 @@
# Managing Models
You can use `mlx-lm` to manage models downloaded locally in your machine. They
are stored in the Hugging Face cache.
Scan models:
```shell
mlx_lm.manage --scan
```
Specify a `--pattern` to get info on a single or specific set of models:
```shell
mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
```
To delete a model (or multiple models):
```shell
mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
```

View File

@@ -1,50 +0,0 @@
# Model Merging
You can use `mlx-lm` to merge models and upload them to the Hugging
Face hub or save them locally for LoRA fine tuning.
The main command is `mlx_lm.merge`:
```shell
mlx_lm.merge --config config.yaml
```
The merged model will be saved by default in `mlx_merged_model`. To see a
full list of options run:
```shell
mlx_lm.merge --help
```
Here is an example `config.yaml`:
```yaml
models:
- OpenPipe/mistral-ft-optimized-1218
- mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
t:
- filter: self_attn
value: [0, 0.5, 0.3, 0.7, 1]
- filter: mlp
value: [1, 0.5, 0.7, 0.3, 0]
- value: 0.5
```
The `models` field is a list of Hugging Face repo ids. The first model in the
list is treated as the base model into which the remaining models are merged.
The `method` field is the merging method. Right now `slerp` is the only
supported method.
The `parameters` are the corresponding parameters for the given `method`.
Each parameter is a list with `filter` determining which layer the parameter
applies to and `value` determining the actual value used. The last item in
the list without a `filter` field is the default.
If `value` is a list, it specifies the start and end values for the
corresponding segment of blocks. In the example above, the models have 32
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
will use `np.linspace(0.5, 0.3, 8)`, and so on.

View File

@@ -1,10 +0,0 @@
## Generate Text with MLX and :hugs: Hugging Face
This an example of large language model text generation that can pull models from
the Hugging Face Hub.
For more information on this example, see the [README](../README.md) in the
parent directory.
This package also supports fine tuning with LoRA or QLoRA. For more information
see the [LoRA documentation](LORA.md).

View File

@@ -1,131 +0,0 @@
# HTTP Model Server
You use `mlx-lm` to make an HTTP API for generating text with any supported
model. The HTTP API is intended to be similar to the [OpenAI chat
API](https://platform.openai.com/docs/api-reference).
> [!NOTE]
> The MLX LM server is not recommended for production as it only implements
> basic security checks.
Start the server with:
```shell
mlx_lm.server --model <path_to_model_or_hf_repo>
```
For example:
```shell
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
```
This will start a text generation server on port `8080` of the `localhost`
using Mistral 7B instruct. The model will be downloaded from the provided
Hugging Face repo if it is not already in the local cache.
To see a full list of options run:
```shell
mlx_lm.server --help
```
You can make a request to the model by running:
```shell
curl localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Say this is a test!"}],
"temperature": 0.7
}'
```
### Request Fields
- `messages`: An array of message objects representing the conversation
history. Each message object should have a role (e.g. user, assistant) and
content (the message text).
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
to generate. Defaults to `100`.
- `stream`: (Optional) A boolean indicating if the response should be
streamed. If true, responses are sent as they are generated. Defaults to
false.
- `temperature`: (Optional) A float specifying the sampling temperature.
Defaults to `1.0`.
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
Defaults to `1.0`.
- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
Defaults to `1.0`.
- `repetition_context_size`: (Optional) The size of the context window for
applying repetition penalty. Defaults to `20`.
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
values. Defaults to `None`.
- `logprobs`: (Optional) An integer specifying the number of top tokens and
corresponding log probabilities to return for each output in the generated
sequence. If set, this can be any value between 1 and 10, inclusive.
- `model`: (Optional) A string path to a local model or Hugging Face repo id.
If the path is local is must be relative to the directory the server was
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completion", "chat.completion.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
### List Models
Use the `v1/models` endpoint to list available models:
```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```
This will return a list of locally available models where each model in the
list contains the following fields:
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.

View File

@@ -1,37 +0,0 @@
### Packaging for PyPI
Install `build` and `twine`:
```
pip install --user --upgrade build
pip install --user --upgrade twine
```
Generate the source distribution and wheel:
```
python -m build
```
> [!warning]
> Use a test server first
#### Test Upload
Upload to test server:
```
python -m twine upload --repository testpypi dist/*
```
Install from test server and check that it works:
```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
```
#### Upload
```
python -m twine upload dist/*
```

View File

@@ -1,9 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import os
from ._version import __version__
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
from .utils import convert, generate, load, stream_generate

View File

@@ -1,3 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.21.0"

View File

@@ -1,161 +0,0 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import sys
import time
import mlx.core as mx
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import generate_step, load
DEFAULT_QUANTIZED_KV_START = 5000
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--prompt-cache-file",
help="The file to save the prompt cache in",
required=True,
)
parser.add_argument(
"--prompt",
required=True,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and tokenizer.chat_template is not None:
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, add_generation_prompt=False, continue_final_message=True
)
else:
prompt = tokenizer.encode(args.prompt)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(prompt)
# Process the prompt
start = time.time()
max_msg_len = 0
def callback(processed, total_tokens):
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
nonlocal max_msg_len
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
for _ in generate_step(
y,
model,
max_tokens=0,
prompt_cache=cache,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
prompt_progress_callback=callback,
):
pass
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
print("Saving...")
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
save_prompt_cache(args.prompt_cache_file, cache, metadata)
if __name__ == "__main__":
main()

View File

@@ -1,89 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
default=DEFAULT_MODEL,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
)
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
if query == "q":
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response.text, flush=True, end="")
print()
if __name__ == "__main__":
main()

View File

@@ -1,62 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from .utils import convert
def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
parser.add_argument(
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
)
parser.add_argument(
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
)
parser.add_argument(
"--q-group-size", help="Group size for quantization.", type=int, default=64
)
parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
parser.add_argument(
"--dtype",
help="Type to save the non-quantized parameters.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
parser.add_argument(
"-d",
"--dequantize",
help="Dequantize a quantized model.",
action="store_true",
default=False,
)
return parser
def main():
parser = configure_parser()
args = parser.parse_args()
convert(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -1,392 +0,0 @@
# Copyright © 2024 Apple Inc.
"""
Adapted from a PyTorch implementation by David Grangier
"""
import argparse
import json
import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional, Union
import lm_eval
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
from .models.cache import make_prompt_cache
from .utils import load, stream_generate
PAD = 0
def _len_longest_common_prefix(a, b):
l = 0
for item_a, item_b in zip(a, b):
if item_a != item_b:
break
l += 1
return l
def _rstrip_until(s, untils):
"""Limit a string <s> to the first occurrence of any substring in untils."""
l = len(s)
f = [s.find(u) for u in untils]
f = [l if x < 0 else x for x in f]
return s[: min(f)]
def _pad_inputs(
inputs,
maxlen,
genlen=0,
pad_left=False,
pad_multiple=32,
truncate=False,
):
# pad the prompts to the left with at least genlen tokens.
actual_maxlen = max(len(p) for p in inputs) + genlen
if actual_maxlen > maxlen:
if not truncate:
raise ValueError("Inputs are too long.")
else: # drop begining
actual_maxlen = maxlen
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
if pad_multiple > 0:
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
maxlen *= pad_multiple
assert PAD == 0
lr = np.array((1, 0) if pad_left else (0, 1))
return np.stack(
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
axis=0,
)
@register_model("mlxlm")
class MLXLM(LM):
def __init__(
self,
path_or_hf_repo: str,
batch_size: int = 16,
max_tokens: Optional[int] = None,
use_chat_template: Optional[bool] = None,
) -> None:
super().__init__()
self._batch_size = batch_size
self._model, self.tokenizer = load(path_or_hf_repo)
self._max_tokens = max_tokens or self.tokenizer.model_max_length
self.use_chat_template = use_chat_template or (
self.tokenizer.chat_template is not None
)
def _score_fn(self, inputs, tokenize=True, step_size=32):
if tokenize:
inputs = self._tokenize(inputs)
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
inputs = mx.array(inputs)
inputs, targets = inputs[..., :-1], inputs[..., 1:]
cache = make_prompt_cache(self._model)
mask = targets != PAD
scores, is_greedy = [], []
for i in range(0, inputs.shape[1], step_size):
logits = self._model(inputs[:, i : i + step_size], cache=cache)
log_probs = nn.log_softmax(logits.astype(mx.float32))
score = mx.take_along_axis(
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
)[..., 0]
ig = mask[:, i : i + step_size] * (
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
)
mx.eval(score, ig)
mx.metal.clear_cache()
is_greedy.append(ig)
scores.append(score)
scores = mx.concatenate(scores, axis=1)
is_greedy = mx.concatenate(is_greedy, axis=1)
return scores, mask.sum(axis=-1), is_greedy
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
sorted_spans = None
if score_spans is not None:
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
results = []
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
batch = sorted_inputs[i : i + self._batch_size]
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
for j in range(len(batch)):
if sorted_spans is None: # full sequence score
mask = mx.arange(scores[j].shape[-1]) < length
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
else: # subsequence score
start, end = sorted_spans[i + j]
score = scores[j][start:end].astype(mx.float32).sum()
ig = is_greedy[j][start:end].astype(mx.int32).sum()
length = end - start
results.append((score.item(), ig.item(), length))
# reorder the outputs
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(results))]
return results
def _tokenize(self, texts):
return [
tuple(
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
)
for t in texts
]
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
:param requests: list[Instance]
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
`context: str`
Context string. Implementations of LM must be able to handle an
empty context string.
`continuation: str`
The continuation over which log likelihood will be calculated. If
there is a word boundary, the space should be in the continuation.
For example, context="hello" continuation=" world" is correct.
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
`logprob: float`
The log probability of `continuation`.
`isgreedy`:
Whether `continuation` would be generated by greedy sampling from `context`.
"""
logging.info("Estimating loglikelihood for %d pairs." % len(requests))
# tokenize prefix and prefix + completion for all requests.
tokenized = self._tokenize(
[t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
)
# max length (prefix + completion) and longest common prefix per question.
length_stats = {}
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
length_stats[prefix] = (
max(max_completed_l, len(completed)),
min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
)
# truncate requests for completed sequences longer than model context.
shortened = []
completion_spans = []
long_completions = 0
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, prefix_l = length_stats[prefix]
# compute truncation length
truncation = max(0, max_completed_l - self._max_tokens - 1)
prefix_l = prefix_l - truncation
if prefix_l <= 0:
# completion too long, prefix is eliminated for some requests.
long_completions += 1
truncation = max(0, len(completed) - self._max_tokens - 1)
prefix_l = 1
# truncate the completed sequence
completed = completed[truncation:]
shortened.append(completed)
# scores do not include initial bos, substract 1 to span bounds
completion_spans.append((prefix_l - 1, len(completed) - 1))
if long_completions > 0:
logging.info(
f"Prefix eliminated for {long_completions} requests with "
+ "completion longer than context."
)
# model scoring, returns num_requests x (logp, is_greedy, length).
results = self._loglikelihood(
shortened,
score_spans=completion_spans,
tokenize=False,
)
return [(r[0], r[1] == r[2]) for r in results]
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `context` conditioned on the EOT token.
"""
logging.info(
"Estimating loglikelihood rolling for %d sequences." % len(requests)
)
inputs = [req.args[0] for req in requests]
return [t[0] for t in self._loglikelihood(inputs)]
def generate_until(self, requests) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, until).
context: str
Context string
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list[str]
A list of strings continuation
continuation: str
The generated continuation.
"""
logging.info("Generating continuation for %d sequences." % len(requests))
contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
keys = list(options[0].keys())
assert "until" in keys
untils = [x["until"] for x in options]
completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
context = self._tokenize(context)
max_tokens = min(
self._max_tokens,
self.tokenizer.model_max_length - len(context),
)
text = ""
for response in stream_generate(
self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
):
text += response.text
if any(u in text for u in until):
text = _rstrip_until(text, until)
completions.append(text)
break
else:
completions.append(text)
return completions
def main():
parser = argparse.ArgumentParser(
"Evaluate an MLX model using lm-evaluation-harness."
)
parser.add_argument("--model", help="Model to evaluate", required=True)
parser.add_argument("--tasks", nargs="+", required=True)
parser.add_argument(
"--output-dir", default=".", help="Output directory for result files."
)
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
parser.add_argument(
"--max-tokens",
type=int,
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
)
parser.add_argument(
"--limit",
default=1.0,
help="Limit the number of examples per task.",
type=float,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
"--fewshot-as-multiturn",
action="store_true",
help="Whether to provide the fewshot examples as a multiturn "
"conversation or a single user turn.",
default=False,
)
parser.add_argument(
"--apply-chat-template",
action=argparse.BooleanOptionalAction,
help="Specifies whether to apply a chat template to the prompt. If "
"the model has a chat template, this defaults to `True`, "
"otherwise `False`.",
default=None,
)
args = parser.parse_args()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Silence tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
mx.random.seed(args.seed)
lm = MLXLM(
args.model,
batch_size=args.batch_size,
max_tokens=args.max_tokens,
use_chat_template=args.apply_chat_template,
)
results = lm_eval.simple_evaluate(
model=lm,
tasks=args.tasks,
fewshot_as_multiturn=args.fewshot_as_multiturn,
apply_chat_template=lm.use_chat_template,
num_fewshot=args.num_shots,
limit=args.limit,
random_seed=args.seed,
numpy_random_seed=args.seed,
torch_random_seed=args.seed,
fewshot_random_seed=args.seed,
)
model_name = args.model.replace("/", "_")
task_names = "_".join(args.tasks)
ver = version("lm_eval")
filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
output_path = output_dir / filename
output_path.write_text(json.dumps(results["results"], indent=4))
print("Results:")
for result in results["results"].values():
print(json.dumps(result, indent=4))

View File

@@ -1,48 +0,0 @@
# Copyright © 2024 Apple Inc.
"""
An example of a multi-turn chat with prompt caching.
"""
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
prompt_cache=prompt_cache,
)
# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")

View File

@@ -1,33 +0,0 @@
# Copyright © 2024 Apple Inc.
from mlx_lm import generate, load
# Specify the checkpoint
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)
# Specify the prompt and conversation history
prompt = "Why is the sky blue?"
conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
conversation=conversation, add_generation_prompt=True
)
# Specify the maximum number of tokens
max_tokens = 1_000
# Specify if tokens and timing information will be printed
verbose = True
# Generate a response with the specified settings
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
verbose=verbose,
)

View File

@@ -1,80 +0,0 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
# The PRNG seed
seed: 0
# Number of layers to fine-tune
num_layers: 16
# Minibatch size.
batch_size: 4
# Iterations to train for.
iters: 1000
# Number of validation batches, -1 uses the entire validation set.
val_batches: 25
# Adam learning rate.
learning_rate: 1e-5
# Number of training steps between loss reporting.
steps_per_report: 10
# Number of training steps between validations.
steps_per_eval: 200
# Load path to resume training with the given adapter weights.
resume_adapter_file: null
# Save/load path for the trained adapter weights.
adapter_path: "adapters"
# Save the model every N iterations.
save_every: 100
# Evaluate on the test set after training
test: false
# Number of test set batches, -1 uses the entire test set.
test_batches: 100
# Maximum sequence length.
max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.
# These will be applied for the last lora_layers
keys: ["self_attn.q_proj", "self_attn.v_proj"]
rank: 8
scale: 20.0
dropout: 0.0
# Schedule can only be specified in a config file, uncomment to use.
#lr_schedule:
# name: cosine_decay
# warmup: 100 # 0 for no warmup
# warmup_init: 1e-7 # 0 if not specified
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
#hf_dataset:
# name: "billsum"
# train_split: "train[:1000]"
# valid_split: "train[-100:]"
# prompt_feature: "text"
# completion_feature: "summary"

View File

@@ -1,11 +0,0 @@
models:
- OpenPipe/mistral-ft-optimized-1218
- mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
t:
- filter: self_attn
value: [0, 0.5, 0.3, 0.7, 1]
- filter: mlp
value: [1, 0.5, 0.7, 0.3, 0]
- value: 0.5

View File

@@ -1,127 +0,0 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
```
mlx.launch \
--hostfile /path/to/hosts.txt \
--backend mpi \
/path/to/pipeline_generate.py \
--prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
"""
import argparse
import json
from pathlib import Path
import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx_lm import load, stream_generate
from mlx_lm.utils import load_model, load_tokenizer
def download(repo: str, allow_patterns: list[str]) -> Path:
return Path(
snapshot_download(
repo,
allow_patterns=allow_patterns,
)
)
def shard_and_load(repo):
# Get model path with everything but weight safetensors
model_path = download(
args.model,
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
)
# Lazy load and shard model to figure out
# which weights we need
model, _ = load_model(model_path, lazy=True, strict=False)
group = mx.distributed.init(backend="mpi")
rank = group.rank()
model.model.pipeline(group)
# Figure out which files we need for the local shard
with open(model_path / "model.safetensors.index.json", "r") as fid:
weight_index = json.load(fid)["weight_map"]
local_files = set()
for k, _ in tree_flatten(model.parameters()):
local_files.add(weight_index[k])
# Download weights for local shard
download(args.model, allow_patterns=local_files)
# Load and shard the model, and load the weights
tokenizer = load_tokenizer(model_path)
model, _ = load_model(model_path, lazy=True, strict=False)
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument(
"--prompt",
"-p",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
group = mx.distributed.init(backend="mpi")
rank = group.rank()
def rprint(*args, **kwargs):
if rank == 0:
print(*args, **kwargs)
model, tokenizer = shard_and_load(args.model)
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
model, tokenizer, prompt, max_tokens=args.max_tokens
):
rprint(response.text, end="", flush=True)
rprint()
rprint("=" * 10)
rprint(
f"Prompt: {response.prompt_tokens} tokens, "
f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
f"Generation: {response.generation_tokens} tokens, "
f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -1,130 +0,0 @@
import argparse
import glob
import shutil
from pathlib import Path
from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import dequantize, load_adapters
from .utils import (
fetch_from_hub,
get_model_path,
save_config,
save_weights,
upload_to_hub,
)
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Fuse fine-tuned adapters into the base model."
)
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--save-path",
default="fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
"--adapter-path",
type=str,
default="adapters",
help="Path to the trained adapter weights and config.",
)
parser.add_argument(
"--hf-path",
type=str,
default=None,
help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
parser.add_argument(
"--de-quantize",
help="Generate a de-quantized model.",
action="store_true",
)
parser.add_argument(
"--export-gguf",
help="Export model weights in GGUF format.",
action="store_true",
)
parser.add_argument(
"--gguf-path",
help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
default="ggml-model-f16.gguf",
type=str,
)
return parser.parse_args()
def main() -> None:
print("Loading pretrained model")
args = parse_arguments()
model_path = get_model_path(args.model)
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
model = load_adapters(model, args.adapter_path)
fused_linears = [
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
]
if fused_linears:
model.update_modules(tree_unflatten(fused_linears))
if args.de_quantize:
print("De-quantizing model")
model = dequantize(model)
weights = dict(tree_flatten(model.parameters()))
save_path = Path(args.save_path)
save_weights(save_path, weights)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, save_path)
tokenizer.save_pretrained(save_path)
if args.de_quantize:
config.pop("quantization", None)
save_config(config, config_path=save_path / "config.json")
if args.export_gguf:
model_type = config["model_type"]
if model_type not in ["llama", "mixtral", "mistral"]:
raise ValueError(
f"Model type {model_type} not supported for GGUF conversion."
)
convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))
if args.upload_repo is not None:
hf_path = args.hf_path or (
args.model if not Path(args.model).exists() else None
)
if hf_path is None:
raise ValueError(
"Must provide original Hugging Face repo to upload local model."
)
upload_to_hub(args.save_path, args.upload_repo, hf_path)
if __name__ == "__main__":
main()

View File

@@ -1,257 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import sys
import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
def str2bool(string):
return string.lower() not in ["false", "f"]
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="LLM inference script")
parser.add_argument(
"--model",
type=str,
help=(
"The path to the local model directory or Hugging Face repo. "
f"If no model is specified, then {DEFAULT_MODEL} is used."
),
default=None,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--extra-eos-token",
type=str,
default=(),
nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
)
parser.add_argument(
"--system-prompt",
default=None,
help="System prompt to be used for the chat template",
)
parser.add_argument(
"--prompt",
"-p",
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument(
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
)
parser.add_argument(
"--min-tokens-to-keep",
type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--verbose",
type=str2bool,
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
parser.add_argument(
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--prompt-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
if using_cache:
prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file,
return_metadata=True,
)
if isinstance(prompt_cache[0], QuantizedKVCache):
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
raise ValueError(
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
)
if args.kv_group_size != prompt_cache[0].group_size:
raise ValueError(
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
)
# Building tokenizer_config
tokenizer_config = (
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
tokenizer_config["trust_remote_code"] = True
model_path = args.model
if using_cache:
if model_path is None:
model_path = metadata["model"]
elif model_path != metadata["model"]:
raise ValueError(
f"Providing a different model ({model_path}) than that "
f"used to create the prompt cache ({metadata['model']}) "
"is an error."
)
model_path = model_path or DEFAULT_MODEL
model, tokenizer = load(
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
prompt = sys.stdin.read() if prompt == "-" else prompt
if not args.ignore_chat_template and tokenizer.chat_template is not None:
if args.system_prompt is not None:
messages = [{"role": "system", "content": args.system_prompt}]
else:
messages = []
messages.append({"role": "user", "content": prompt})
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if using_cache:
messages[-1]["content"] = "<query>"
test_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
verbose=args.verbose,
sampler=sampler,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
if not args.verbose:
print(response)
if __name__ == "__main__":
main()

View File

@@ -1,314 +0,0 @@
import re
from enum import IntEnum
from pathlib import Path
from typing import Iterable, Optional, Set, Tuple, Union
import mlx.core as mx
from transformers import AutoTokenizer
class TokenType(IntEnum):
NORMAL = 1
UNKNOWN = 2
CONTROL = 3
USER_DEFINED = 4
UNUSED = 5
BYTE = 6
class GGMLFileType(IntEnum):
GGML_TYPE_F16 = 1
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
class HfVocab:
def __init__(
self,
fname_tokenizer: Path,
fname_added_tokens: Optional[Union[Path, None]] = None,
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
fname_tokenizer,
cache_dir=fname_tokenizer,
local_files_only=True,
)
self.added_tokens_list = []
self.added_tokens_dict = dict()
self.added_tokens_ids = set()
for tok, tokidx in sorted(
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
):
if tokidx >= self.tokenizer.vocab_size:
self.added_tokens_list.append(tok)
self.added_tokens_dict[tok] = tokidx
self.added_tokens_ids.add(tokidx)
self.specials = {
tok: self.tokenizer.get_vocab()[tok]
for tok in self.tokenizer.all_special_tokens
}
self.special_ids = set(self.tokenizer.all_special_ids)
self.vocab_size_base = self.tokenizer.vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
reverse_vocab = {
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
}
for token_id in range(self.vocab_size_base):
if token_id in self.added_tokens_ids:
continue
token_text = reverse_vocab[token_id]
yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, token_text, self.special_ids
)
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
def get_token_score(self, token_id: int) -> float:
return -1000.0
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
score = self.get_token_score(self.specials[text])
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
yield text, score, toktype
def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
yield from self.hf_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
@staticmethod
def load(path: Path) -> "HfVocab":
added_tokens_path = path.parent / "added_tokens.json"
return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
def translate_weight_names(name):
name = name.replace("model.layers.", "blk.")
# for mixtral gate
name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
# for mixtral experts ffns
pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
replacement = r"ffn_gate.\1.weight"
name = re.sub(pattern, replacement, name)
pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
replacement = r"ffn_down.\1.weight"
name = re.sub(pattern, replacement, name)
pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
replacement = r"ffn_up.\1.weight"
name = re.sub(pattern, replacement, name)
name = name.replace("mlp.gate_proj", "ffn_gate")
name = name.replace("mlp.down_proj", "ffn_down")
name = name.replace("mlp.up_proj", "ffn_up")
name = name.replace("self_attn.q_proj", "attn_q")
name = name.replace("self_attn.k_proj", "attn_k")
name = name.replace("self_attn.v_proj", "attn_v")
name = name.replace("self_attn.o_proj", "attn_output")
name = name.replace("input_layernorm", "attn_norm")
name = name.replace("post_attention_layernorm", "ffn_norm")
name = name.replace("model.embed_tokens", "token_embd")
name = name.replace("model.norm", "output_norm")
name = name.replace("lm_head", "output")
return name
def permute_weights(weights, n_head, n_head_kv=None):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
reshaped = weights.reshape(
n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
)
swapped = reshaped.swapaxes(1, 2)
final_shape = weights.shape
return swapped.reshape(final_shape)
def prepare_metadata(config, vocab):
metadata = {
"general.name": "llama",
"llama.context_length": (
mx.array(config["max_position_embeddings"], dtype=mx.uint32)
if config.get("max_position_embeddings") is not None
else None
),
"llama.embedding_length": (
mx.array(config["hidden_size"], dtype=mx.uint32)
if config.get("hidden_size") is not None
else None
),
"llama.block_count": (
mx.array(config["num_hidden_layers"], dtype=mx.uint32)
if config.get("num_hidden_layers") is not None
else None
),
"llama.feed_forward_length": (
mx.array(config["intermediate_size"], dtype=mx.uint32)
if config.get("intermediate_size") is not None
else None
),
"llama.rope.dimension_count": (
mx.array(
config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
)
if config.get("hidden_size") is not None
and config.get("num_attention_heads") is not None
else None
),
"llama.attention.head_count": (
mx.array(config["num_attention_heads"], dtype=mx.uint32)
if config.get("num_attention_heads") is not None
else None
),
"llama.attention.head_count_kv": (
mx.array(
config.get("num_key_value_heads", config["num_attention_heads"]),
dtype=mx.uint32,
)
if config.get("num_attention_heads") is not None
else None
),
"llama.expert_count": (
mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
if config.get("num_local_experts") is not None
else None
),
"llama.expert_used_count": (
mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
if config.get("num_experts_per_tok") is not None
else None
),
"llama.attention.layer_norm_rms_epsilon": (
mx.array(config.get("rms_norm_eps", 1e-05))
if config.get("rms_norm_eps") is not None
else None
),
"llama.rope.freq_base": (
mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
if config.get("rope_theta") is not None
else None
),
}
rope_scaling = config.get("rope_scaling")
if rope_scaling is not None and (typ := rope_scaling.get("type")):
rope_factor = rope_scaling.get("factor")
f_rope_scale = rope_factor
if typ == "linear":
rope_scaling_type = "linear"
metadata["llama.rope.scaling.type"] = rope_scaling_type
metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)
metadata["general.file_type"] = mx.array(
GGMLFileType.GGML_TYPE_F16.value,
dtype=mx.uint32,
)
metadata["general.quantization_version"] = mx.array(
GGMLFileType.GGML_TYPE_F16.value,
dtype=mx.uint32,
)
metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
metadata["general.architecture"] = "llama"
metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)
# add metadata for vocab
metadata["tokenizer.ggml.model"] = "llama"
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype.value)
assert len(tokens) == vocab.vocab_size
metadata["tokenizer.ggml.tokens"] = tokens
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
if vocab.tokenizer.bos_token_id is not None:
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
vocab.tokenizer.bos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.eos_token_id is not None:
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
vocab.tokenizer.eos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.unk_token_id is not None:
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
vocab.tokenizer.unk_token_id, dtype=mx.uint32
)
metadata = {k: v for k, v in metadata.items() if v is not None}
return metadata
def convert_to_gguf(
model_path: Union[str, Path],
weights: dict,
config: dict,
output_file_path: str,
):
if isinstance(model_path, str):
model_path = Path(model_path)
quantization = config.get("quantization", None)
if quantization:
raise NotImplementedError(
"Conversion of quantized models is not yet supported."
)
print("Converting to GGUF format")
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention
weights = {
k: (
permute_weights(
v, config["num_attention_heads"], config["num_attention_heads"]
)
if "self_attn.q_proj.weight" in k
else (
permute_weights(
v, config["num_attention_heads"], config["num_key_value_heads"]
)
if "self_attn.k_proj.weight" in k
else v
)
)
for k, v in weights.items()
}
# rename weights for gguf format
weights = {translate_weight_names(k): v for k, v in weights.items()}
if not (model_path / "tokenizer.json").exists():
raise ValueError("Tokenizer json not found")
vocab = HfVocab.load(model_path)
metadata = prepare_metadata(config, vocab)
weights = {
k: (
v.astype(mx.float32).astype(mx.float16)
if v.dtype == mx.bfloat16
else v.astype(mx.float32) if "norm" in k else v
)
for k, v in weights.items()
}
output_file_path = output_file_path
mx.save_gguf(output_file_path, weights, metadata)
print(f"Converted GGUF model saved as: {output_file_path}")

View File

@@ -1,299 +0,0 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import re
import types
from pathlib import Path
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
"num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
"learning_rate": 1e-5,
"steps_per_report": 10,
"steps_per_eval": 200,
"resume_adapter_file": None,
"adapter_path": "adapters",
"save_every": 100,
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
)
# Training args
parser.add_argument(
"--train",
action="store_true",
help="Do training",
default=None,
)
parser.add_argument(
"--data",
type=str,
help=(
"Directory with {train, valid, test}.jsonl files or the name "
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
),
)
parser.add_argument(
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
parser.add_argument("--iters", type=int, help="Iterations to train for.")
parser.add_argument(
"--val-batches",
type=int,
help="Number of validation batches, -1 uses the entire validation set.",
)
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
parser.add_argument(
"--steps-per-report",
type=int,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps-per-eval",
type=int,
help="Number of training steps between validations.",
)
parser.add_argument(
"--resume-adapter-file",
type=str,
help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
type=int,
help="Save the model every N iterations.",
)
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
default=None,
)
parser.add_argument(
"--test-batches",
type=int,
help="Number of test set batches, -1 uses the entire test set.",
)
parser.add_argument(
"--max-seq-length",
type=int,
help="Maximum sequence length.",
)
parser.add_argument(
"-c",
"--config",
type=str,
help="A YAML configuration file with the training options",
)
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser
def train_model(
args,
model: nn.Module,
tokenizer: TokenizerWrapper,
train_set,
valid_set,
training_callback: TrainingCallback = None,
):
model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-min(args.num_layers, 0) :]:
l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
# Convert linear layers to lora/dora layers and unfreeze in the process
linear_to_lora_layers(
model,
args.num_layers,
args.lora_parameters,
use_dora=(args.fine_tune_type == "dora"),
)
else:
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
# Resume from weights if provided
if args.resume_adapter_file is not None:
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)
model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
load_adapters(model, args.adapter_path)
elif args.train:
print("Training")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")
if args.test:
print("Testing")
evaluate_model(args, model, tokenizer, test_set)
def main():
os.environ["TOKENIZERS_PARALLELISM"] = "true"
parser = build_parser()
args = parser.parse_args()
config = args.config
args = vars(args)
if config:
print("Loading configuration file", config)
with open(config, "r") as file:
config = yaml.load(file, yaml_loader)
# Prefer parameters from command-line arguments
for k, v in config.items():
if args.get(k, None) is None:
args[k] = v
# Update defaults for unspecified parameters
for k, v in CONFIG_DEFAULTS.items():
if args.get(k, None) is None:
args[k] = v
run(types.SimpleNamespace(**args))
if __name__ == "__main__":
main()

View File

@@ -1,124 +0,0 @@
import argparse
from typing import List, Union
from huggingface_hub import scan_cache_dir
from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool:
"""Ask user for confirmation with Y/N prompt.
Returns True for Y/yes, False for N/no/empty."""
y = ("y", "yes", "1")
n = ("n", "no", "0", "")
full_message = f"{message} (y/n) "
while True:
answer = input(full_message).lower()
if answer in y:
return True
if answer in n:
return False
print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
def main():
parser = argparse.ArgumentParser(description="MLX Model Cache.")
parser.add_argument(
"--scan",
action="store_true",
help="Scan Hugging Face cache for mlx models.",
)
parser.add_argument(
"--delete",
action="store_true",
help="Delete models matching the given pattern.",
)
parser.add_argument(
"--pattern",
type=str,
help="Model repos contain the pattern.",
default="mlx",
)
args = parser.parse_args()
if args.scan:
print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
hf_cache_info = scan_cache_dir()
print(
tabulate(
rows=[
[
repo.repo_id,
repo.repo_type,
"{:>12}".format(repo.size_on_disk_str),
repo.nb_files,
repo.last_accessed_str,
repo.last_modified_str,
str(repo.repo_path),
]
for repo in sorted(
hf_cache_info.repos, key=lambda repo: repo.repo_path
)
if args.pattern in repo.repo_id
],
headers=[
"REPO ID",
"REPO TYPE",
"SIZE ON DISK",
"NB FILES",
"LAST_ACCESSED",
"LAST_MODIFIED",
"LOCAL PATH",
],
)
)
if args.delete:
print(f'Deleting models matching pattern "{args.pattern}"')
hf_cache_info = scan_cache_dir()
repos = [
repo
for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
if args.pattern in repo.repo_id
]
if repos:
print("\nFound the following models:")
print(
tabulate(
rows=[
[
repo.repo_id,
repo.size_on_disk_str, # Added size information
str(repo.repo_path),
]
for repo in repos
],
headers=[
"REPO ID",
"SIZE", # Added size header
"LOCAL PATH",
],
)
)
confirmed = ask_for_confirmation(
"\nAre you sure you want to delete these models?"
)
if confirmed:
for model_info in repos:
print(f"\nDeleting {model_info.repo_id}...")
for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash
):
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute()
print("\nModel(s) deleted successfully.")
else:
print("\nDeletion cancelled - no changes made.")
else:
print(f'No models found matching pattern "{args.pattern}"')
if __name__ == "__main__":
main()

View File

@@ -1,172 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import glob
import shutil
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import yaml
from mlx.utils import tree_flatten, tree_map
from .utils import (
fetch_from_hub,
get_model_path,
save_config,
save_weights,
upload_to_hub,
)
def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
parser = argparse.ArgumentParser(description="Merge multiple models.")
parser.add_argument("--config", type=str, help="Path to the YAML config.")
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_merged_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
return parser
def slerp(t, w1, w2, eps=1e-5):
"""
Spherical linear interpolation
Args:
t (float): Interpolation weight in [0.0, 1.0]
w1 (mx.array): First input
w2 (mx.array): Second input
eps (float): Constant for numerical stability
Returns:
mx.array: Interpolated result
"""
t = float(t)
if t == 0:
return w1
elif t == 1:
return w2
# Normalize
v1 = w1 / mx.linalg.norm(w1)
v2 = w2 / mx.linalg.norm(w2)
# Angle
dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
theta = mx.arccos(dot)
sin_theta = mx.sin(theta + eps)
s1 = mx.sin(theta * (1 - t)) / sin_theta
s2 = mx.sin(theta * t) / sin_theta
return s1 * w1 + s2 * w2
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
method = config.get("method", None)
if method != "slerp":
raise ValueError(f"Merge method {method} not supported")
num_layers = len(model.layers)
def unpack_values(vals):
if isinstance(vals, (int, float)):
return np.full(num_layers, vals)
bins = len(vals) - 1
sizes = [num_layers // bins] * bins
sizes[-1] = num_layers - sum(sizes[:-1])
return np.concatenate(
[np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
)
param_list = config["parameters"]["t"]
params = {}
filter_keys = set()
for pl in param_list[:-1]:
params[pl["filter"]] = unpack_values(pl["value"])
filter_keys.add(pl["filter"])
default = unpack_values(param_list[-1]["value"])
for e in range(num_layers):
bl = base_model.layers[e]
l = model.layers[e]
base_weights = bl.parameters()
weights = l.parameters()
for k, w1 in base_weights.items():
w2 = weights[k]
t = params.get(k, default)[e]
base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
base_model.update(base_weights)
def merge(
config: str,
mlx_path: str = "mlx_model",
upload_repo: Optional[str] = None,
):
with open(config, "r") as fid:
merge_conf = yaml.safe_load(fid)
print("[INFO] Loading")
model_paths = merge_conf.get("models", [])
if len(model_paths) < 2:
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
# Load all models
base_hf_path = model_paths[0]
base_path = get_model_path(base_hf_path)
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
models = []
for mp in model_paths[1:]:
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
base_type = base_config["model_type"]
model_type = model_config["model_type"]
if base_type != model_type:
raise ValueError(
f"Can only merge models of the same type,"
f" but got {base_type} and {model_type}."
)
models.append(model)
# Merge models into base model
for m in models:
merge_models(base_model, m, merge_conf)
# Save base model
mlx_path = Path(mlx_path)
weights = dict(tree_flatten(base_model.parameters()))
del models, base_model
save_weights(mlx_path, weights, donate_weights=True)
py_files = glob.glob(str(base_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
save_config(config, config_path=mlx_path / "config.json")
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, base_hf_path)
def main():
parser = configure_parser()
args = parser.parse_args()
merge(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -1,121 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
from mlx.utils import tree_map
from .cache import QuantizedKVCache
@dataclass
class BaseModelArgs:
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def create_causal_mask(
N: int,
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None
return mask
def quantized_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array, mx.array],
q_values: tuple[mx.array, mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
group_size: int = 64,
bits: int = 8,
) -> mx.array:
B, n_q_heads, L, D = queries.shape
n_kv_heads = q_keys[0].shape[-3]
n_repeats = n_q_heads // n_kv_heads
queries *= scale
if n_repeats > 1:
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
scores = mx.quantized_matmul(
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
)
if mask is not None:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = mx.quantized_matmul(
scores, *q_values, transpose=False, group_size=group_size, bits=bits
)
if n_repeats > 1:
out = mx.reshape(out, (B, n_q_heads, L, D))
return out
def scaled_dot_product_attention(
queries,
keys,
values,
cache,
scale: float,
mask: Optional[mx.array],
) -> mx.array:
if isinstance(cache, QuantizedKVCache):
return quantized_scaled_dot_product_attention(
queries,
keys,
values,
scale=scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
return mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask
)

View File

@@ -1,438 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
) -> List[Any]:
"""
Construct the model's cache for use when cgeneration.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 8):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
v_head_dim = values.shape[-1]
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
el_per_int = 8 * mx.uint32.size // self.bits
new_steps = (self.step + num_steps - 1) // self.step * self.step
shape = (B, n_kv_heads, new_steps)
def init_quant(dim):
return (
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
)
def expand_quant(x):
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
return mx.concatenate([x, new_x], axis=-2)
if self.keys is not None:
if prev % self.step != 0:
self.keys, self.values = tree_map(
lambda x: x[..., :prev, :], (self.keys, self.values)
)
self.keys, self.values = tree_map(
expand_quant, (self.keys, self.values)
)
else:
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
self.offset += num_steps
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
@property
def state(self):
if self.offset == self.keys[0].shape[2]:
return self.keys, self.values
else:
return tree_map(
lambda x: x[..., : self.offset, :], (self.keys, self.values)
)
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
@meta_state.setter
def meta_state(self, v):
self.step, self.offset, self.group_size, self.bits = map(int, v)
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
quant_cache.offset = self.offset
if self.keys is not None:
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(
self.values, group_size=group_size, bits=bits
)
return quant_cache
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("RotatingKVCache Quantization NYI")
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v

View File

@@ -1,195 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 8192
num_hidden_layers: int = 40
intermediate_size: int = 22528
num_attention_heads: int = 64
num_key_value_heads: int = 64
rope_theta: float = 8000000.0
vocab_size: int = 256000
layer_norm_eps: float = 1e-05
logit_scale: float = 0.0625
attention_bias: bool = False
layer_norm_bias: bool = False
use_qk_norm: bool = False
class LayerNorm2D(nn.Module):
def __init__(self, d1, d2, eps):
super().__init__()
self.weight = mx.zeros((d1, d2))
self.eps = eps
def __call__(self, x):
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // args.num_attention_heads
self.scale = head_dim**-0.5
attetion_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
self.use_qk_norm = args.use_qk_norm
if self.use_qk_norm:
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
self.k_norm = LayerNorm2D(
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1)
keys = keys.reshape(B, L, self.n_kv_heads, -1)
if self.use_qk_norm:
queries = self.q_norm(queries)
keys = self.k_norm(keys)
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x
class CohereModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out
@property
def layers(self):
return self.model.layers

View File

@@ -1,206 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 4096
head_dim: int = 128
num_hidden_layers: int = 32
intermediate_size: int = 14336
num_attention_heads: int = 32
num_key_value_heads: int = 8
rope_theta: float = 50000.0
vocab_size: int = 256000
layer_norm_eps: float = 1e-05
logit_scale: float = 0.0625
attention_bias: bool = False
layer_norm_bias: bool = False
sliding_window: int = 4096
sliding_window_pattern: int = 4
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.args = args
self.layer_idx = layer_idx
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
if (head_dim * n_heads) != dim:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
f" and `num_heads`: {n_heads})."
)
self.scale = head_dim**-0.5
attetion_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
# Apply RoPE only if sliding window is enabled
if self.use_sliding_window:
if cache is None:
queries = self.rope(queries)
keys = self.rope(keys)
else:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
if self.use_sliding_window and mask is not None:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
mask = mask[..., -key_len:]
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads
self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x
class CohereModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=i)
for i in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if cache is None:
cache = [None] * len(self.layers)
if mask is None:
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out
def make_cache(self):
caches = []
for i in range(self.args.num_hidden_layers):
if (
i % self.args.sliding_window_pattern
== self.args.sliding_window_pattern - 1
):
caches.append(KVCache())
else:
caches.append(
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
)
return caches
@property
def layers(self):
return self.model.layers

View File

@@ -1,254 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
d_model: int
ffn_config: dict
attn_config: dict
n_layers: int
n_heads: int
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_heads = args.n_heads
self.d_model = args.d_model
self.head_dim = args.d_model // args.n_heads
self.num_key_value_heads = args.attn_config["kv_n_heads"]
self.clip_qkv = args.attn_config["clip_qkv"]
self.rope_theta = args.attn_config["rope_theta"]
self.scale = self.head_dim**-0.5
self.Wqkv = nn.Linear(
args.d_model,
(self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
bias=False,
)
self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
self.rope = nn.RoPE(
self.head_dim,
traditional=False,
base=self.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
queries, keys, values = mx.split(qkv, splits, axis=-1)
B, L, D = x.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
class NormAttnNorm(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
self.attn = Attention(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
return x, self.norm_2(x)
class MLP(nn.Module):
def __init__(self, d_model: int, ffn_dim: int):
super().__init__()
self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class Router(nn.Module):
def __init__(self, d_model: int, num_experts: int):
super().__init__()
self.layer = nn.Linear(d_model, num_experts, bias=False)
def __call__(self, x: mx.array):
return self.layer(x)
class SparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.d_model = args.d_model
self.ffn_dim = args.ffn_config["ffn_hidden_size"]
self.num_experts = args.ffn_config["moe_num_experts"]
self.num_experts_per_tok = args.ffn_config["moe_top_k"]
self.router = Router(self.d_model, self.num_experts)
self.experts = [
MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
]
def __call__(self, x: mx.array) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.router(x)
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
scores = scores.astype(x.dtype)
if self.training:
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt)
y = mx.stack(y, axis=0)
return y.reshape(orig_shape)
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ffn = SparseMoeBlock(args)
self.norm_attn_norm = NormAttnNorm(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
return out
class DBRX(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.vocab_size = args.vocab_size
self.wte = nn.Embedding(args.vocab_size, args.d_model)
self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
self.norm_f = nn.LayerNorm(args.d_model, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
for layer, c in zip(self.blocks, cache):
h = layer(h, mask, c)
return self.norm_f(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.transformer = DBRX(args)
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, mask, cache)
return self.lm_head(out)
@property
def layers(self):
return self.transformer.blocks
def sanitize(self, weights):
# Split experts into sub matrices
num_experts = self.args.ffn_config["moe_num_experts"]
dim = self.args.ffn_config["ffn_hidden_size"]
pattern = "experts.mlp"
new_weights = {k: v for k, v in weights.items() if pattern not in k}
for k, v in weights.items():
if pattern in k:
experts = [
(k.replace(".mlp", f".{e}") + ".weight", sv)
for e, sv in enumerate(mx.split(v, num_experts, axis=0))
]
if k.endswith("w2"):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights

View File

@@ -1,261 +0,0 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
attention_bias: bool = False
class DeepseekAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5
attention_bias = getattr(config, "attention_bias", False)
self.q_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
self.k_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
rope_scale = 1.0
if config.rope_scaling and config.rope_scaling["type"] == "linear":
assert isinstance(config.rope_scaling["factor"], float)
rope_scale = 1 / config.rope_scaling["factor"]
self.rope = nn.RoPE(
self.head_dim,
base=config.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size or config.hidden_size
self.intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(scores, inds, axis=-1)
return inds, scores
class DeepseekMoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekAttention(config)
self.mlp = (
DeepseekMoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekMLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for m in ["gate_proj", "down_proj", "up_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -1,460 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v2"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "gready"
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV2YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class DeepseekV2Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV2YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV2MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
if self.topk_method == "group_limited_greedy":
bsz, seq_len = x.shape[:2]
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = scores.max(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV2MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV2DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV2Attention(config)
self.mlp = (
DeepseekV2MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV2MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekV2Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV2DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * self.num_layers
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@@ -1,478 +0,0 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v3"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "noaux_tc"
scoring_func: str = "sigmoid"
norm_topk_prob: bool = True
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV3YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
# A clipped silu to prevent fp16 from overflowing
@partial(mx.compile, shapeless=True)
def clipped_silu(x):
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV3YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV3MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.sigmoid(gates.astype(mx.float32))
assert self.topk_method == "noaux_tc", "Unsupported topk method."
bsz, seq_len = x.shape[:2]
scores = scores + self.e_score_correction_bias
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV3MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size,
config.moe_intermediate_size,
config.n_routed_experts,
activation=clipped_silu,
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV3MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV3DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV3Attention(config)
self.mlp = (
DeepseekV3MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV3MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r
class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = (
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * self.num_layers
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, cache, mask)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
# Remove multi-token prediction layer
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
@property
def layers(self):
return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@@ -1,166 +0,0 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_layers: int
intermediate_size: int
num_attention_heads: int
vocab_size: int
rope_theta: float
layer_norm_epsilon: float
num_key_value_heads: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
attention_bias: bool = False
mlp_bias: bool = False
class AttentionModule(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or (dim // n_heads)
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
) -> mx.array:
B, L, D = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
q = self.rope(q, offset=cache.offset)
k = self.rope(k, offset=cache.offset)
k, v = cache.update_and_fetch(k, v)
else:
q = self.rope(q)
k = self.rope(k)
out = scaled_dot_product_attention(
q, k, v, cache=cache, scale=self.scale, mask=mask
)
out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
return self.out_proj(out)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = AttentionModule(args)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)
def __call__(self, x: mx.array) -> mx.array:
return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = x + self.attn.attention(self.ln_1(x), mask, cache)
out = h + self.mlp(self.ln_2(h))
return out
class ExaoneModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
h = layer(h, mask, cache=c)
return self.ln_f(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = ExaoneModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.h

View File

@@ -1,178 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
return out
@property
def layers(self):
return self.model.layers

View File

@@ -1,203 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
attn_logit_softcapping: float = 50.0
final_logit_softcapping: float = 30.0
query_pre_attn_scalar: float = 144.0
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
self.head_dim = head_dim = args.head_dim
self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.attn_logit_softcapping = args.attn_logit_softcapping
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries * self.scale
if self.repeats > 1:
queries = queries.reshape(
B, self.n_kv_heads, self.repeats, L, self.head_dim
)
keys = mx.expand_dims(keys, 2)
values = mx.expand_dims(values, 2)
scores = queries @ keys.swapaxes(-1, -2)
scores = mx.tanh(scores / self.attn_logit_softcapping)
scores *= self.attn_logit_softcapping
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, precise=True, axis=-1)
output = scores @ values
if self.repeats > 1:
output = output.reshape(B, self.n_heads, L, self.head_dim)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.final_logit_softcapping = args.final_logit_softcapping
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping
return out
@property
def layers(self):
return self.model.layers

View File

@@ -1,201 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
n_ctx: int
n_embd: int
n_head: int
n_layer: int
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.n_head
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
self.n_embd = args.n_embd
self.n_head = args.n_head
self.head_dim = self.n_embd // self.n_head
self.scale = self.head_dim**-0.5
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.c_attn(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_embd = args.n_embd
self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)
def __call__(self, x) -> mx.array:
return self.c_proj(nn.gelu_approx(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_head = args.n_head
self.n_embd = args.n_embd
self.layer_norm_epsilon = args.layer_norm_epsilon
self.attn = Attention(args)
self.mlp = MLP(args)
self.ln_1 = nn.LayerNorm(
self.n_embd,
eps=self.layer_norm_epsilon,
)
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
r = self.mlp(self.ln_2(h))
out = h + r
return out
class GPT2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_embd = args.n_embd
self.n_positions = args.n_positions
self.vocab_size = args.vocab_size
self.n_layer = args.n_layer
self.layer_norm_epsilon = args.layer_norm_epsilon
assert self.vocab_size > 0
self.wte = nn.Embedding(self.vocab_size, self.n_embd)
self.wpe = nn.Embedding(self.n_positions, self.n_embd)
self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
return self.ln_f(hidden_states)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GPT2Model(args)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.wte.as_linear(out)
return out
def sanitize(self, weights):
new_weights = {}
for i in range(self.args.n_layer):
if f"h.{i}.attn.bias" in weights:
del weights[f"h.{i}.attn.bias"]
if f"h.{i}.attn.c_attn.weight" in weights:
weights[f"h.{i}.attn.c_attn.weight"] = weights[
f"h.{i}.attn.c_attn.weight"
].transpose(1, 0)
if f"h.{i}.attn.c_proj.weight" in weights:
weights[f"h.{i}.attn.c_proj.weight"] = weights[
f"h.{i}.attn.c_proj.weight"
].transpose(1, 0)
if f"h.{i}.mlp.c_fc.weight" in weights:
weights[f"h.{i}.mlp.c_fc.weight"] = weights[
f"h.{i}.mlp.c_fc.weight"
].transpose(1, 0)
if f"h.{i}.mlp.c_proj.weight" in weights:
weights[f"h.{i}.mlp.c_proj.weight"] = weights[
f"h.{i}.mlp.c_proj.weight"
].transpose(1, 0)
for weight in weights:
if not weight.startswith("model."):
new_weights[f"model.{weight}"] = weights[weight]
else:
new_weights[weight] = weights[weight]
return new_weights
@property
def layers(self):
return self.model.h

View File

@@ -1,189 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
n_embd: int
n_layer: int
n_inner: int
n_head: int
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
multi_query: bool = True
attention_bias: bool = True
mlp_bias: bool = True
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = 1 if self.multi_query else self.n_head
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = dim = args.n_embd
self.n_heads = n_heads = args.n_head
self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
self.head_dim = head_dim = dim // n_heads
self.kv_dim = n_kv_heads * head_dim
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
self.c_proj = nn.Linear(dim, dim, bias=attention_bias)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.c_attn(x)
queries, keys, values = mx.split(
qkv, [self.dim, self.dim + self.kv_dim], axis=-1
)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.n_embd
hidden_dim = args.n_inner
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.c_proj(nn.gelu(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_head = args.n_head
self.n_embd = args.n_embd
self.attn = Attention(args)
self.mlp = MLP(args)
self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
r = self.mlp(self.ln_2(h))
out = h + r
return out
class GPTBigCodeModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
assert self.vocab_size > 0
self.wte = nn.Embedding(args.vocab_size, args.n_embd)
self.wpe = nn.Embedding(args.n_positions, args.n_embd)
self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
B, L = inputs.shape
hidden_states = self.wte(inputs)
mask = None
if mask is not None and hidden_states.shape[1] > 1:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
position_ids = mx.array(np.arange(L))
else:
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
hidden_states += self.wpe(position_ids)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
return self.ln_f(hidden_states)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = GPTBigCodeModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.h

View File

@@ -1,219 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
# Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
max_position_embeddings: int
hidden_size: int
num_attention_heads: int
num_hidden_layers: int
layer_norm_eps: float
vocab_size: int
rotary_emb_base: int
rotary_pct: float
num_key_value_heads: int = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert (
args.hidden_size % args.num_attention_heads == 0
), "hidden_size must be divisible by num_attention_heads"
self.hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_attention_heads
self.rope = nn.RoPE(
dims=int(self.head_dim * args.rotary_pct),
traditional=False,
base=args.rotary_emb_base,
)
self.scale = self.head_dim**-0.5
self.query_key_value = nn.Linear(
self.hidden_size, 3 * self.hidden_size, bias=True
)
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.query_key_value(x)
new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
qkv = qkv.reshape(*new_qkv_shape)
queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
def __call__(self, x) -> mx.array:
# gelu_approx corresponds to FastGELUActivation in transformers.
return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.layer_norm_eps = args.layer_norm_eps
self.attention = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.LayerNorm(
self.hidden_size,
eps=self.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm(
self.hidden_size, eps=self.layer_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
attn = self.attention(self.input_layernorm(x), mask, cache)
ffn = self.mlp(self.post_attention_layernorm(x))
out = attn + ffn + residual
return out
class GPTNeoXModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.layer_norm_eps = args.layer_norm_eps
assert self.vocab_size > 0
self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
hidden_states = self.embed_in(inputs)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
out = self.final_layer_norm(hidden_states)
out = self.embed_out(out)
return out
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GPTNeoXModel(args)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
return out
def sanitize(self, weights):
new_weights = {}
for w_key, w_value in weights.items():
# Created through register_buffer in Pytorch, not needed here.
ignore_suffixes = [
".attention.bias",
".attention.masked_bias",
".attention.rotary_emb.inv_freq",
]
skip_weight = False
for ignored_suffix in ignore_suffixes:
if w_key.endswith(ignored_suffix):
skip_weight = True
break
if skip_weight:
continue
if not w_key.startswith("model."):
w_key = f"model.{w_key}"
w_key = w_key.replace(".gpt_neox.layers.", ".h.")
w_key = w_key.replace(".gpt_neox.", ".")
new_weights[w_key] = w_value
return new_weights
@property
def layers(self):
return self.model.h

View File

@@ -1,185 +0,0 @@
# Copyright © 2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
rms_norm_eps: float
vocab_size: int
attention_bias: bool
head_dim: int
max_position_embeddings: int
mlp_bias: bool
model_type: str
rope_theta: float
tie_word_embeddings: bool
class HeliumAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class HeliumMLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.intermediate_size = args.intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
)
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class HeliumDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = HeliumAttention(args)
self.mlp = HeliumMLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class HeliumModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_hidden_layers = args.num_hidden_layers
self.vocab_size = args.vocab_size
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = HeliumModel(args)
self.vocab_size = args.vocab_size
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers

View File

@@ -1,294 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
attention_bias: bool
moe_topk: int
num_experts: int
num_shared_expert: int
use_mixed_mlp_moe: bool
use_qk_norm: bool
rms_norm_eps: float
rope_theta: float
use_cla: bool
cla_share_factor: 2
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
class DynamicNTKAlphaRoPE(nn.Module):
def __init__(
self,
dims: int,
base: float = 10000,
scaling_alpha: float = 1.0,
):
super().__init__()
self.dims = dims
base = base * scaling_alpha ** (dims / (dims - 2))
self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=False,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class Attention(nn.Module):
def __init__(self, kv_proj: bool, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
if kv_proj:
self.k_proj = nn.Linear(
dim, n_kv_heads * head_dim, bias=args.attention_bias
)
self.v_proj = nn.Linear(
dim, n_kv_heads * head_dim, bias=args.attention_bias
)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.use_qk_norm = args.use_qk_norm
if self.use_qk_norm:
self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.rope = DynamicNTKAlphaRoPE(
head_dim,
base=args.rope_theta,
scaling_alpha=args.rope_scaling["alpha"],
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
kv_states=None,
) -> mx.array:
B, L, D = x.shape
queries = self.q_proj(x)
if kv_states is None:
keys, values = self.k_proj(x), self.v_proj(x)
kv_states = keys, values
else:
keys, values = kv_states
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
offset = cache.offset if cache else 0
queries = self.rope(queries, offset=offset)
keys = self.rope(keys, offset=offset)
if self.use_qk_norm:
queries = self.query_layernorm(queries)
keys = self.key_layernorm(keys)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), kv_states
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Gate(nn.Module):
def __init__(self, dim, num_experts):
super().__init__()
self.wg = nn.Linear(dim, num_experts, bias=False)
def __call__(self, x) -> mx.array:
return self.wg(x)
class MoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
intermediate_size = args.intermediate_size
self.use_shared_mlp = args.use_mixed_mlp_moe
if args.use_mixed_mlp_moe:
self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert)
self.num_experts = num_experts = args.num_experts
self.top_k = args.moe_topk
self.gate = Gate(dim, num_experts)
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
def __call__(
self,
x: mx.array,
):
gates = self.gate(x)
gates = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.use_shared_mlp:
shared_expert_output = self.shared_mlp(x)
y = y + shared_expert_output
return y
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs, kv_proj: bool):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(kv_proj, args)
self.mlp = MoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None,
):
r, shared_kv_states = self.self_attn(
self.input_layernorm(x), mask, cache, shared_kv_states
)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, shared_kv_states
class HunYuanModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
for i in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
if i % self.args.cla_share_factor == 0:
shared_kv_states = None
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = HunYuanModel(args)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -1,241 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
bias: bool = True
max_position_embeddings: int = 32768
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
raise ValueError(
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.max_position_embeddings = max_position_embeddings
self.original_base = base
self.dims = dims
self.traditional = traditional
self.scale = scale
def extra_repr(self):
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
if seq_len > self.max_position_embeddings:
base = self.original_base * (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
else:
base = self.original_base
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=base,
scale=self.scale,
offset=offset,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_kv_groups = n_heads // args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.wqkv = nn.Linear(
dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
)
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 2.0
)
self.rope = DynamicNTKScalingRoPE(
head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv_states = self.wqkv(x)
qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)
queries = qkv_states[..., : self.n_kv_groups, :]
queries = queries.reshape(B, L, -1, self.head_dim)
keys = qkv_states[..., -2, :]
values = qkv_states[..., -1, :]
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out
class InternLM2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.tok_embeddings(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = InternLM2Model(args)
if not args.tie_word_embeddings:
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.tok_embeddings.as_linear(out)
else:
out = self.output(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
@property
def layers(self):
return self.model.layers

View File

@@ -1,241 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
bias: bool = False
qkv_bias: bool = False
max_position_embeddings: int = 32768
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "rope_type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
raise ValueError(
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.max_position_embeddings = max_position_embeddings
self.original_base = base
self.dims = dims
self.traditional = traditional
self.scale = scale
def extra_repr(self):
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
if seq_len > self.max_position_embeddings:
base = self.original_base * (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
else:
base = self.original_base
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=base,
scale=self.scale,
offset=offset,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
qkv_bias = args.qkv_bias
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_kv_groups = n_heads // args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None
and args.rope_scaling["rope_type"] == "linear"
else 2.0
)
self.rope = DynamicNTKScalingRoPE(
head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim, bias):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class InternLM2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = InternLM2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
@property
def layers(self):
return self.model.layers

View File

@@ -1,205 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = LlamaModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers

View File

@@ -1,242 +0,0 @@
# Copyright © 2024-2025 Apple Inc.
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
state_size: int
num_hidden_layers: int
conv_kernel: int
use_bias: bool
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
self.hidden_size = self.d_model
if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
self.intermediate_size = self.d_inner
if not hasattr(self, "state_size") and hasattr(self, "d_state"):
self.state_size = self.d_state
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
self.num_hidden_layers = self.n_layer
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
self.num_hidden_layers = self.n_layers
if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
self.conv_kernel = self.d_conv
if not hasattr(self, "use_bias") and hasattr(self, "bias"):
self.use_bias = self.bias
if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
self.use_conv_bias = self.conv_bias
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
groups, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.ssm_state_size = args.state_size
self.conv_kernel_size = args.conv_kernel
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
)
self.conv1d = DepthWiseConv1d(
channels=self.intermediate_size,
kernel_size=self.conv_kernel_size,
bias=self.use_conv_bias,
padding=self.conv_kernel_size - 1,
)
self.x_proj = nn.Linear(
self.intermediate_size,
self.time_step_rank + 2 * self.ssm_state_size,
bias=False,
)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
A = mx.repeat(
mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
repeats=self.intermediate_size,
axis=0,
)
self.A_log = mx.log(A)
self.D = mx.ones([self.intermediate_size])
self.out_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=args.use_bias
)
def ssm_step(self, x, A, state=None):
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = map(
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split(
deltaBC,
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
axis=-1,
),
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x
return y, new_state
def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape
xz = self.in_proj(x)
x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
outputs = []
current_state = state_cache
y = []
for t in range(T):
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
y.append(y_t)
y = mx.stack(y, axis=1)
z = self.out_proj(nn.silu(z) * y)
return z, (new_conv_cache, current_state)
def __call__(self, x, cache):
if cache is None:
conv_cache, state_cache = None, None
else:
conv_cache, state_cache = cache[0], cache[1]
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = MambaBlock(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers

View File

@@ -1,210 +0,0 @@
# Copyright © 2023-2025 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
dim_model_base: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
scale_depth: float
scale_emb: float
rope_theta: float = 1000000.0
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[str, float]]] = None
tie_word_embeddings: bool = False
class MLP(nn.Module):
def __init__(self, args):
super().__init__()
self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.num_heads = n_heads = args.num_attention_heads
self.rope_theta = args.rope_theta
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.num_key_value_heads = args.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
dims=self.head_dim,
traditional=args.rope_traditional,
base=self.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
):
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
attn_output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(attn_output)
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.num_hidden_layers = args.num_hidden_layers
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.scale_depth = args.scale_depth
self.num_hidden_layers = args.num_hidden_layers
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
r = self.mlp(self.post_attention_layernorm(h))
out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
return out
class MiniCPMModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.args.scale_emb
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = MiniCPMModel(args)
if not self.args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
else:
out = out @ self.model.embed_tokens.weight.T
return out
def sanitize(self, weights):
if "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -1,220 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int = 32000
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_experts_per_tok: int = 2
num_key_value_heads: int = 8
num_local_experts: int = 8
rms_norm_eps: float = 1e-5
rope_theta: float = 1e6
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class MixtralAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.num_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = args.num_key_value_heads
self.rope_theta = args.rope_theta
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.rope = nn.RoPE(
self.head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.intermediate_size
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class MixtralDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = MixtralAttention(args)
self.block_sparse_moe = MixtralSparseMoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.block_sparse_moe(self.post_attention_layernorm(h))
out = h + r
return out
class MixtralModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = MixtralModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -1,220 +0,0 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
hidden_act: str
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
norm_eps: float
vocab_size: int
num_key_value_heads: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
partial_rotary_factor: float = 0.5
rope_theta: float = 10000.0
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
"rope_type"
)
if rope_type is None:
raise ValueError(
f"rope_scaling must contain either 'type' or 'rope_type'"
)
if rope_type not in ["linear"]:
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@partial(mx.compile, shapeless=True)
def relu_squared(x):
return nn.relu(x).square()
class NemotronLayerNorm1P(nn.LayerNorm):
def __call__(self, x):
weight = self.weight + 1 if "weight" in self else None
bias = self.bias if "bias" in self else None
return mx.fast.layer_norm(x, weight, bias, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.partial_rotary_factor = args.partial_rotary_factor
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
mlp_bias = args.mlp_bias
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(relu_squared(self.up_proj(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
self.post_attention_layernorm = NemotronLayerNorm1P(
args.hidden_size, eps=args.norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class NemotronModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = NemotronModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers

View File

@@ -1,180 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import sys
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
sys.exit(1)
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
d_model: int
n_layers: int
mlp_hidden_size: int
n_heads: int
vocab_size: int
embedding_size: int
rope_theta: float = 10000
rope_traditional: bool = False
mlp_ratio: int = 4
weight_tying: bool = False
def __post_init__(self):
self.mlp_hidden_size = (
self.mlp_hidden_size
if self.mlp_hidden_size is not None
else self.mlp_ratio * self.d_model
)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
dim = args.d_model
self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
self.att_norm = nn.LayerNorm(dim, affine=False)
self.ff_norm = nn.LayerNorm(dim, affine=False)
head_dim = dim // self.n_heads
self.scale = head_dim**-0.5
self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
self.attn_out = nn.Linear(dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
self.args = args
def attend(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.attn_out(output)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
out = h + self.ff_out(nn.silu(x2) * x1)
return out
class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_layers = args.n_layers
self.weight_tying = args.weight_tying
self.wte = nn.Embedding(args.embedding_size, args.d_model)
self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
if not self.weight_tying:
self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
self.norm = nn.LayerNorm(args.d_model, affine=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.wte(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
for block, c in zip(self.blocks, cache):
h = block(h, mask, c)
h = self.norm(h)
if self.weight_tying:
return self.wte.as_linear(h), cache
return self.ff_out(h)
class OlmoModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.transformer = Transformer(args)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.transformer(inputs, mask, cache)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = OlmoModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
return self.model(inputs, mask, cache)
@property
def layers(self):
return self.model.transformer.blocks

View File

@@ -1,212 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = self.q_norm(queries)
keys = self.k_norm(keys)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
h = x + r
r = self.post_feedforward_layernorm(self.mlp(h))
out = h + r
return out
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
mask=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = LlamaModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
mask=None,
):
out = self.model(inputs, cache, mask)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers

View File

@@ -1,223 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
head_dim: int
num_transformer_layers: int
model_dim: int
vocab_size: int
ffn_dim_divisor: int
num_query_heads: List
num_kv_heads: List
ffn_multipliers: List
ffn_with_glu: bool = True
normalize_qk_projections: bool = True
share_input_output_layers: bool = True
rms_norm_eps: float = 1e-6
rope_freq_constant: float = 10000
def make_divisible(
v: Union[float, int],
divisor: Optional[int] = 8,
min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by the divisor
It can be seen at:
https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
Args:
v: input value
divisor: default to 8
min_value: minimum divisor value
Returns:
new_v: new divisible value
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
super().__init__()
self.head_dim = head_dim = args.head_dim
self.layer_id = layer_id
self.model_dim = model_dim = args.model_dim
self.n_heads = n_heads = args.num_query_heads[layer_id]
self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id]
self.scale = head_dim**-0.5
op_size = (n_heads + (n_kv_heads * 2)) * head_dim
self.qkv_proj = nn.Linear(model_dim, op_size, bias=False)
self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)
self.normalize_qk_projections = args.normalize_qk_projections
if self.normalize_qk_projections:
self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.qkv_proj(x)
qkv = qkv.reshape(
B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
).transpose(0, 2, 1, 3)
queries, keys, values = mx.split(
qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
)
# Prepare the queries, keys and values for the attention computation
if self.normalize_qk_projections:
queries = self.q_norm(queries)
keys = self.k_norm(keys)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
super().__init__()
self.args = args
dim = args.model_dim
ffn_multiplier = args.ffn_multipliers[layer_id]
intermediate_dim = int(
make_divisible(
ffn_multiplier * args.model_dim,
divisor=args.ffn_dim_divisor,
)
)
self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False)
self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)
def __call__(self, x) -> mx.array:
x = self.proj_1(x)
gate, x = mx.split(x, 2, axis=-1)
return self.proj_2(nn.silu(gate) * x)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
super().__init__()
dim = args.model_dim
self.attn = Attention(args, layer_id=layer_id)
self.ffn = MLP(args, layer_id=layer_id)
self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
r = self.ffn(self.ffn_norm(h))
out = h + r
return out
class OpenELMModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_transformer_layers = args.num_transformer_layers
assert self.vocab_size > 0
self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
self.layers = [
TransformerBlock(args, layer_id=layer_id)
for layer_id in range(self.num_transformer_layers)
]
self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.token_embeddings(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = OpenELMModel(args)
if not args.share_input_output_layers:
self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.transformer(inputs, mask, cache)
if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.layers

View File

@@ -1,179 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "phi"
max_position_embeddings: int = 2048
vocab_size: int = 51200
hidden_size: int = 2560
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: int = 32
partial_rotary_factor: float = 0.4
intermediate_size: int = 10240
layer_norm_eps: float = 1e-5
rope_theta: float = 10000.0
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class PhiAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.repeats = self.num_heads // self.num_key_value_heads
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.dense = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=True
)
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
traditional=False,
base=self.rope_theta,
)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
B, L, D = queries.shape
n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(
B,
L,
n_heads,
-1,
).moveaxis(1, 2)
keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
return self.dense(output)
class PhiMLP(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class PhiDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.self_attn = PhiAttention(config=config)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.mlp = PhiMLP(config)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x
class PhiModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
self.final_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if mask is None:
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, mask, c)
return self.final_layernorm(x)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = PhiModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
y = self.model(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers

View File

@@ -1,207 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"long_factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
print(
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
)
self.rope_scaling = None
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.num_hidden_layers = args.num_hidden_layers
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
head_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.qkv_proj(x)
query_pos = self.n_heads * self.head_dim
queries, keys, values = mx.split(
qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
def __call__(self, x) -> mx.array:
x = self.gate_up_proj(x)
gate, x = mx.split(x, 2, axis=-1)
return self.down_proj(nn.silu(gate) * x)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Phi3Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers

View File

@@ -1,313 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
dense_attention_every_n_layers: int
ff_intermediate_size: int
gegelu_limit: float
num_hidden_layers: int
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@partial(mx.compile, shapeless=True)
def gegelu_impl(a_gelu, a_linear, limit):
a_gelu = mx.where(
mx.isinf(a_gelu),
a_gelu,
mx.clip(a_gelu, a_min=None, a_max=limit),
)
a_linear = mx.where(
mx.isinf(a_linear),
a_linear,
mx.clip(a_linear, a_min=-limit, a_max=limit),
)
out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
return out_gelu * (a_linear + 1.0)
def gegelu(x, limit):
a_gelu, a_linear = x[..., ::2], x[..., 1::2]
return gegelu_impl(a_gelu, a_linear, limit)
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.query_key_value = nn.Linear(
dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
)
self.dense = nn.Linear(dim, dim)
if args.mup_use_scaling:
norm_factor = head_dim / args.mup_attn_multiplier
else:
norm_factor = math.sqrt(head_dim)
self.scale = 1.0 / norm_factor
self.rope = nn.RoPE(
head_dim,
traditional=False,
base=args.rope_embedding_base,
scale=args.rope_position_scale,
)
if layer_idx % args.dense_attention_every_n_layers == 0:
self.block_sparse = True
self.blocksparse_block_size = args.blocksparse_block_size
if self.blocksparse_block_size not in (32, 64):
raise ValueError(
f"Unsupported block size {self.blocksparse_block_size}"
)
self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
self.blocksparse_vert_stride = args.blocksparse_vert_stride
else:
self.block_sparse = False
def _block_sparse_mask(self, q_len, kv_len):
vert_stride = self.blocksparse_vert_stride
local_blocks = self.blocksparse_num_local_blocks
block_size = self.blocksparse_block_size
n_heads = self.n_heads
kv_blocks = (kv_len + block_size - 1) // block_size
q_blocks = (q_len + block_size - 1) // block_size
q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
k_pos = mx.arange(kv_blocks)[None, None]
mask_vert_strided = (
mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
) % vert_stride
mask_vert_strided = (mask_vert_strided == 0)[:, None, :]
block_mask = (q_pos >= k_pos) & (
(q_pos - k_pos < local_blocks) | mask_vert_strided
)
block_mask = block_mask.reshape(
self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
)
dense_mask = mx.repeat(
mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
)
return block_mask, dense_mask[..., -q_len:, :kv_len]
def _block_sparse_attention(self, queries, keys, values, scale, mask):
queries = scale * queries
B = queries.shape[0]
L = queries.shape[2]
queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
keys = mx.expand_dims(keys, 2)
values = mx.expand_dims(values, 2)
# TODO get rid of dense mask if we have a fill value
block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
scores = queries @ mx.swapaxes(keys, -1, -2)
# TODO, uncomment when faster
# scores = mx.block_masked_mm(
# queries,
# mx.swapaxes(keys, -1, -2),
# mask_out=block_mask,
# block_size=self.blocksparse_block_size,
# )
if mask is not None:
scores = scores + mask
scores = scores + mx.where(
dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
)
scores = mx.softmax(scores, axis=-1, precise=True)
output = scores @ values
# TODO, uncomment when faster
# output = mx.block_masked_mm(
# scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
# )
return mx.reshape(output, (B, self.n_heads, L, -1))
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.query_key_value(x)
qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
queries = qkv[..., :-2, :].flatten(-3, -2)
keys = qkv[..., -2, :]
values = qkv[..., -1, :]
# Prepare the queries, keys and values for the attention computation
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
if self.block_sparse:
output = self._block_sparse_attention(
queries, keys, values, scale=self.scale, mask=mask
)
else:
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output)
class MLP(nn.Module):
def __init__(self, args):
super().__init__()
dim = args.hidden_size
hidden_dim = args.ff_intermediate_size
self.gegelu_limit = args.gegelu_limit
self.up_proj = nn.Linear(dim, 2 * hidden_dim)
self.down_proj = nn.Linear(hidden_dim, dim)
def __call__(self, x) -> mx.array:
x = self.up_proj(x)
return self.down_proj(gegelu(x, self.gegelu_limit))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_epsilon
)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size,
eps=args.layer_norm_epsilon,
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.mup_embedding_multiplier = args.mup_embedding_multiplier
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=l)
for l in range(args.num_hidden_layers)
]
self.final_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_epsilon
)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.final_layernorm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Phi3Model(args)
self.args = args
self.mup_width_multiplier = args.mup_width_multiplier
self._dummy_tokenizer_ids = mx.array(
[100256, 100258, 100259, 100260, 100264, 100265]
+ list(range(100267, 100352))
)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.embed_tokens.as_linear(out)
if self.mup_width_multiplier:
out = out / self.mup_width_multiplier
out[self._dummy_tokenizer_ids] = -float("inf")
return out
@property
def layers(self):
return self.model.layers
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}

View File

@@ -1,214 +0,0 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "phimoe"
vocab_size: int = 32064
hidden_size: int = 4096
intermediate_size: int = 6400
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
rms_norm_eps: float = 1e-6
rope_scaling: Dict[str, Union[float, List[float]]] = None
num_local_experts: int = 16
num_experts_per_tok: int = 2
rope_theta: float = 10000.0
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope = SuScaledRotaryEmbedding(
head_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
short_mscale=args.rope_scaling["short_mscale"],
long_mscale=args.rope_scaling["long_mscale"],
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache=None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class PhiMoESparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.intermediate_size
self.num_experts = args.num_local_experts
self.top_k = args.num_experts_per_tok
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class PhiMoEDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.block_sparse_moe = PhiMoESparseMoeBlock(args)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache=None,
) -> mx.array:
residual = x
hidden_states = self.input_layernorm(x)
hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class PhiMoEModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -1,202 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchMLP
@dataclass
class ModelArgs:
model_type: str
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
num_layers: int = 32
rotary_dim: int = 32
num_experts_per_tok: int = 2
num_local_experts: int = 4
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(rotary_dim, traditional=False)
self.Wqkv = nn.Linear(dims, 3 * dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x, mask=None, cache=None):
qkv = self.Wqkv(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
return self.out_proj(output)
class MOE(nn.Module):
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
self.switch_mlp = SwitchMLP(
self.dim, self.hidden_dim, self.num_experts, bias=True
)
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = nn.LayerNorm(dims)
self.moe = MOE(config, dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h = self.mixer(h, mask, cache)
ff_h = self.moe(h)
return attn_h + ff_h + x
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embd = Embd(config)
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
x = self.embd(x)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
x = layer(x, mask, c)
return x
class Embd(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
def __call__(self, x):
return self.wte(x)
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):
return self.linear(self.ln(inputs))
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
if mask is None:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)
def sanitize(self, weights):
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
return weights
for l in range(self.args.num_layers):
prefix = f"transformer.h.{l}"
for n in ["fc1", "fc2"]:
for k in ["weight", "scales", "biases", "bias"]:
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.transformer.h

View File

@@ -1,214 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
n_shared_head: int = 8
rope_theta: float = 10000
rope_traditional: bool = False
class Attention(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
head_dim = self.hidden_size // config.num_attention_heads
self.q_num_heads = config.num_attention_heads
self.qk_dim = self.v_dim = head_dim
self.k_num_heads = self.v_num_heads = int(
np.ceil(self.q_num_heads / config.n_shared_head)
)
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(
self.hidden_size, self.q_num_heads * self.qk_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.k_num_heads * self.qk_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.v_num_heads * self.v_dim, bias=False
)
self.o_proj = nn.Linear(
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
)
self.rotary_emb = nn.RoPE(
head_dim,
traditional=config.rope_traditional,
base=config.rope_theta,
scale=1.0,
)
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
0, 2, 1, 3
)
keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
0, 2, 1, 3
)
values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rotary_emb(queries, offset=cache.offset)
keys = self.rotary_emb(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self.scale,
mask=attention_mask,
)
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
class PlamoDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.mlp = MLP(config)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
):
# from LlamaDecoder
residual = hidden_states
hidden_states = self.norm(hidden_states)
# Self Attention
hidden_states_sa = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
cache=cache,
)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
hidden_states = residual + hidden_states_sa + hidden_states_mlp
return hidden_states
class PlamoDecoder(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.layers = [
PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers)
]
class PlamoModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = PlamoDecoder(config) # type: ignore
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
for layer, c in zip(self.layers.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model_type = args.model_type
self.model = PlamoModel(args)
self.lm_head: nn.Module = nn.Linear(
args.hidden_size, args.vocab_size, bias=False
)
self.args = args
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
mask: Optional[mx.array] = None,
) -> mx.array:
out = self.model(inputs, cache, mask)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers

View File

@@ -1,159 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
kv_channels: int = 128
max_position_embeddings: int = 8192
layer_norm_epsilon: float = 1e-6
intermediate_size: int = 11008
no_bias: bool = True
vocab_size: int = 151936
num_key_value_heads = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
self.scale = hidden_size_per_attention_head**-0.5
def __call__(self, x, mask=None, cache=None):
qkv = self.c_attn(x)
q, k, v = mx.split(qkv, 3, axis=-1)
B, L, _ = q.shape
queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rotary_emb(queries, offset=cache.offset)
keys = self.rotary_emb(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.w2 = nn.Linear(
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.c_proj = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
)
def __call__(self, x):
a1 = self.w1(x)
a2 = self.w2(x)
return self.c_proj(a1 * nn.silu(a2))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(self, x, mask=None, cache=None):
residual = x
x = self.ln_1(x)
x = self.attn(x, mask=mask, cache=cache)
residual = x + residual
x = self.ln_2(residual)
x = self.mlp(x)
x = x + residual
return x
class QwenModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
if mask is None:
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
x = layer(x, mask, c)
return self.ln_f(x)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = QwenModel(config)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=not config.no_bias
)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h

View File

@@ -1,201 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Qwen2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers

View File

@@ -1,241 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_experts_per_tok: int
num_experts: int
moe_intermediate_size: int
shared_expert_intermediate_size: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
intermediate_size = args.moe_intermediate_size
shared_expert_intermediate_size = args.shared_expert_intermediate_size
self.num_experts = num_experts = args.num_experts
self.top_k = args.num_experts_per_tok
self.gate = nn.Linear(dim, num_experts, bias=False)
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
def __call__(
self,
x: mx.array,
):
gates = self.gate(x)
gates = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
shared_expert_output = self.shared_expert(x)
shared_expert_output = (
mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
)
return y + shared_expert_output
class Qwen2MoeDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = Qwen2MoeSparseMoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Qwen2MoeModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Qwen2MoeModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -1,458 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
attention_bias: bool
conv1d_width: int
hidden_size: int
intermediate_size: int
logits_soft_cap: float
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
rms_norm_eps: float
rope_theta: float
attention_window_size: int
vocab_size: int
embeddings_scale_by_sqrt_dim: bool = True
block_types: Optional[List[str]] = None
_block_types: Optional[List[str]] = None
def __post_init__(self):
# For some reason these have different names in 2B and 9B
if self.block_types is None:
self.block_types = self._block_types
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def rnn_scan(x, a, h0):
assert x.ndim == 3
assert a.shape == x.shape[-a.ndim :]
assert a.dtype == x.dtype
if x.shape[1] == 1:
# Using scan in sampling mode.
if h0 is None:
return x, x[:, 0]
else:
y = a * h0[:, None] + x
return y, y[:, -1]
else:
# Using scan in linear mode.
if h0 is not None:
h_t = h0
else:
B, _, D = x.shape
h_t = mx.zeros((B, D), dtype=x.dtype)
y = mx.zeros_like(x)
for t in range(x.shape[1]):
h_t = a[:, t] * h_t + x[:, t]
y[:, t] = h_t
return y, h_t
class Conv1d(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
B, L, C = x.shape
groups, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
def __init__(
self,
width: int,
num_heads: int,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.head_dim = self.width // self.num_heads
self.recurrent_param = mx.zeros((self.width,))
self.input_gate_weight = mx.zeros(
(self.num_heads, self.head_dim, self.head_dim),
)
self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
self.recurrent_gate_weight = mx.zeros(
(self.num_heads, self.head_dim, self.head_dim),
)
self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
def __call__(
self,
x: mx.array,
cache=None,
):
B, L, _ = x.shape
def apply_block_linear(h, w, b):
h = h.reshape((B, L, self.num_heads, self.head_dim))
h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
return mx.sigmoid(h.flatten(2, 3))
# Gates for x and a.
gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
gate_a = apply_block_linear(
x, self.recurrent_gate_weight, self.recurrent_gate_bias
)
# Compute the parameter `A` of the recurrence.
log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
a = mx.exp(log_a)
a_square = mx.exp(2 * log_a)
# Gate the input.
gated_x = x * gate_x
# Apply gamma normalization to the input.
multiplier = mx.sqrt(1 - a_square)
if cache is None:
multiplier[:, 0, :] = 1.0
normalized_x = gated_x * multiplier.astype(x.dtype)
y, last_h = rnn_scan(
x=normalized_x,
a=a,
h0=cache,
)
return y, last_h
class RecurrentBlock(nn.Module):
def __init__(
self,
width: int,
num_heads: int,
lru_width: int = None,
conv1d_temporal_width: int = 4,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.lru_width = lru_width or width
self.conv1d_temporal_width = conv1d_temporal_width
self.linear_y = nn.Linear(width, self.lru_width)
self.linear_x = nn.Linear(width, self.lru_width)
self.linear_out = nn.Linear(self.lru_width, width)
self.conv_1d = Conv1d(
channels=self.lru_width,
kernel_size=self.conv1d_temporal_width,
)
self.rg_lru = RGLRU(
width=self.lru_width,
num_heads=self.num_heads,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
# y branch.
y = self.linear_y(x)
y = nn.gelu_approx(y)
# x branch.
x = self.linear_x(x)
if cache is None:
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
return x
class LocalAttentionBlock(nn.Module):
def __init__(
self,
width: int,
num_heads: int,
window_size: int,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.window_size = window_size
self.scale = (width // num_heads) ** (-0.5)
self.head_dim = self.width // self.num_heads
self.q_proj = nn.Linear(self.width, self.width, bias=False)
self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
self.o_proj = nn.Linear(self.width, self.width, bias=True)
self.rope = nn.RoPE(
self.head_dim // 2,
traditional=False,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLPBlock(nn.Module):
def __init__(self, width: int, expanded_width: int):
super().__init__()
self.up_proj = nn.Linear(width, expanded_width // 2)
self.gate_proj = nn.Linear(width, expanded_width // 2)
self.down_proj = nn.Linear(expanded_width // 2, width)
def __call__(self, x: mx.array):
gate = self.gate_proj(x)
x = self.up_proj(x)
return self.down_proj(nn.gelu_approx(gate) * x)
class ResidualBlock(nn.Module):
def __init__(
self,
width: int,
mlp_expanded_width: int,
num_heads: int,
attention_window_size: int,
temporal_block_type: str,
lru_width: Optional[int] = None,
conv1d_temporal_width: int = 4,
):
"""Initializes the residual block.
Args:
width: The width of the block.
mlp_expanded_width: The width of the expansion inside the MLP block.
num_heads: The number of heads for the Attention or the RG-LRU.
attention_window_size: The window size for the local attention block.
temporal_block_type: Either "recurrent" or "attention", specifying the
type of recurrent block to use.
lru_width: The width of the RG-LRU if different from `width`.
conv1d_temporal_width: The width of the temporal convolution.
"""
super().__init__()
self.width = width
self.mlp_expanded_width = mlp_expanded_width
self.num_heads = num_heads
self.attention_window_size = attention_window_size
self.temporal_block_type = temporal_block_type
self.lru_width = lru_width
self.conv1d_temporal_width = conv1d_temporal_width
self.temporal_pre_norm = RMSNorm(width)
if self.temporal_block_type == "recurrent":
self.temporal_block = RecurrentBlock(
width=self.width,
num_heads=self.num_heads,
lru_width=self.lru_width,
conv1d_temporal_width=self.conv1d_temporal_width,
)
else:
self.temporal_block = LocalAttentionBlock(
width=self.width,
num_heads=self.num_heads,
window_size=self.attention_window_size,
)
self.channel_pre_norm = RMSNorm(width)
self.mlp_block = MLPBlock(
width=self.width,
expanded_width=self.mlp_expanded_width,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
raw_x = x
inputs_normalized = self.temporal_pre_norm(raw_x)
x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
residual = x + raw_x
x = self.channel_pre_norm(residual)
x = self.mlp_block(x)
x = x + residual
return x
class Griffin(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
)
self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
block_types = config.block_types
self.layers = [
ResidualBlock(
width=config.hidden_size,
mlp_expanded_width=config.intermediate_size,
num_heads=config.num_attention_heads,
attention_window_size=config.attention_window_size,
temporal_block_type=block_types[i % len(block_types)],
lru_width=None,
)
for i in range(config.num_hidden_layers)
]
self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
tokens,
mask: mx.array = None,
cache=None,
):
x = self.embed_tokens(tokens)
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
if cache is None:
cache = [None] * len(self.layers)
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
if mask is None:
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
return self.final_norm(x)
class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
"""
Args:
tokens: Sequence of input tokens.
"""
logits = self.model(tokens, mask=mask, cache=cache)
if "lm_head" in self:
logits = self.lm_head(logits)
else:
logits = self.model.embed_tokens.as_linear(logits)
c = self.args.logits_soft_cap
if c:
logits = mx.tanh(logits / c) * c
return logits
@property
def layers(self):
return self.model.layers
def sanitize(self, weights):
for k, v in weights.items():
if "conv_1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
def make_cache(self):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(MambaCache())
else:
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache

View File

@@ -1,91 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
class Llama3RoPE(nn.Module):
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scaling_config: dict = None,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
factor = scaling_config["factor"]
low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
old_context_len = scaling_config.get(
"original_max_position_embeddings",
8192,
)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = base ** (mx.arange(0, dims, 2) / dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}"
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
def initialize_rope(
dims,
base,
traditional,
scaling_config: Optional[dict] = None,
max_position_embeddings: Optional[int] = None,
):
if scaling_config is not None:
rope_type = scaling_config.get("type") or scaling_config.get(
"rope_type", "default"
)
else:
rope_type = "default"
if rope_type in ["default", "linear"]:
scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
elif rope_type == "llama3":
return Llama3RoPE(
dims=dims,
max_position_embeddings=max_position_embeddings,
traditional=traditional,
base=base,
scaling_config=scaling_config,
)
else:
raise ValueError(f"Unsupported RoPE type {rope_type}")

View File

@@ -1,211 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
intermediate_size: int
rope_theta: float
use_qkv_bias: bool
partial_rotary_factor: float
layer_norm_eps: float
use_parallel_residual: bool = False
qk_layernorm: bool = False
class LayerNormPerHead(nn.Module):
def __init__(self, head_dim, num_heads, eps):
super().__init__()
self.norms = [
nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
]
self.eps = eps
def __call__(self, x):
w = mx.stack([n.weight for n in self.norms])
return w * mx.fast.layer_norm(x, None, None, self.eps)
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.use_qkv_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.use_qkv_bias,
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
traditional=False,
base=self.rope_theta,
)
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = LayerNormPerHead(
self.head_dim, self.num_heads, eps=config.layer_norm_eps
)
self.k_layernorm = LayerNormPerHead(
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
B, L, D = queries.shape
queries = queries.reshape(B, L, self.num_heads, -1)
keys = keys.reshape(B, L, self.num_key_value_heads, -1)
if self.qk_layernorm:
queries = self.q_layernorm(queries)
keys = self.k_layernorm(keys)
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=scale, mask=mask
).astype(values.dtype)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.self_attn = Attention(config=config)
self.mlp = MLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps,
)
self.use_parallel_residual = config.use_parallel_residual
if not self.use_parallel_residual:
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps,
)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
r = self.self_attn(h, mask, cache)
if self.use_parallel_residual:
out = x + r + self.mlp(h)
else:
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class StableLM(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, mask, cache=c)
return self.norm(x)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = StableLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
if mask is None:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers

View File

@@ -1,169 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
norm_epsilon: float = 1e-5
vocab_size: int = 49152
rope_theta: float = 100000
tie_word_embeddings: bool = True
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // args.num_attention_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)
def __call__(self, x):
return self.c_proj(nn.gelu(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.norm_epsilon
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Starcoder2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers

Some files were not shown because too many files have changed in this diff Show More