mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
51 Commits
distribute
...
13c3892aae
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13c3892aae | ||
|
|
8e293bbc51 | ||
|
|
4b2a0df237 | ||
|
|
977cd30242 | ||
|
|
4c9f9f9be7 | ||
|
|
c52cc748f8 | ||
|
|
c243370044 | ||
|
|
7ca05d2e51 | ||
|
|
d9e1d9c0ef | ||
|
|
2fce02acd8 | ||
|
|
3e5baf583b | ||
|
|
4c3df00162 | ||
|
|
d2e02b3aae | ||
|
|
595f5da146 | ||
|
|
877d2a345b | ||
|
|
32d10036de | ||
|
|
e150621095 | ||
|
|
56d2db23e1 | ||
|
|
e7267d30f8 | ||
|
|
f621218ff5 | ||
|
|
65aa2ec849 | ||
|
|
1bc3476a46 | ||
|
|
269faa5fa4 | ||
|
|
845cd8c01e | ||
|
|
b2108a0de6 | ||
|
|
eb73549631 | ||
|
|
00a7379070 | ||
|
|
0f240a4c7e | ||
|
|
56e60ad5a6 | ||
|
|
b7f742ef56 | ||
|
|
c37e26a1a3 | ||
|
|
09b641aaa7 | ||
|
|
3d793ecf68 | ||
|
|
85669451d0 | ||
|
|
1cbf5cdac7 | ||
|
|
96bf37008e | ||
|
|
7b07b14e67 | ||
|
|
ec30dc3538 | ||
|
|
42413c5d85 | ||
|
|
f8cbf159e0 | ||
|
|
e879ea70e1 | ||
|
|
3d677f0870 | ||
|
|
bded1a8fcd | ||
|
|
5865899c81 | ||
|
|
1ced1b00ca | ||
|
|
f58c7de901 | ||
|
|
1503bd4f55 | ||
|
|
31611b62d7 | ||
|
|
6120a5f376 | ||
|
|
52c41b5b5a | ||
|
|
747c08e202 |
@@ -17,30 +17,6 @@ jobs:
|
||||
pre-commit run --all
|
||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||
|
||||
mlx_lm_build_and_test:
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.large.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.9
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install unittest-xml-reporting
|
||||
cd llms/
|
||||
pip install -e ".[test]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python -m xmlrunner discover -v llms/tests -o test-results/
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
@@ -48,7 +24,6 @@ workflows:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- mlx_lm_build_and_test
|
||||
- linux_build_and_test
|
||||
|
||||
prb:
|
||||
@@ -61,7 +36,5 @@ workflows:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- mlx_lm_build_and_test:
|
||||
requires: [ hold ]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.8.0
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
|
||||
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
||||
- Markus Enzweiler: Added the `cvae` examples.
|
||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||
- Shiyu Li: Added the `Segment Anything Model`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`.
|
||||
@@ -4,12 +4,12 @@ This repo contains a variety of standalone examples using the [MLX
|
||||
framework](https://github.com/ml-explore/mlx).
|
||||
|
||||
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
||||
|
||||
Some more useful examples are listed below.
|
||||
Some more useful examples are listed below. Check-out [MLX
|
||||
LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python
|
||||
package for LLMs with MLX.
|
||||
|
||||
### Text Models
|
||||
|
||||
- [MLX LM](llms/README.md) a package for LLM text generation, fine-tuning, and more.
|
||||
- [Transformer language model](transformer_lm) training.
|
||||
- Minimal examples of large scale text generation with [LLaMA](llms/llama),
|
||||
[Mistral](llms/mistral), and more in the [LLMs](llms) directory.
|
||||
@@ -30,6 +30,7 @@ Some more useful examples are listed below.
|
||||
|
||||
- Speech recognition with [OpenAI's Whisper](whisper).
|
||||
- Audio compression and generation with [Meta's EnCodec](encodec).
|
||||
- Music generation with [Meta's MusicGen](musicgen).
|
||||
|
||||
### Multimodal models
|
||||
|
||||
|
||||
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
|
||||
|
||||
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
|
||||
We intend to update this example once these features are added.
|
||||
|
||||
## Distributed training
|
||||
|
||||
The example also supports distributed data parallel training. You can launch a
|
||||
distributed training as follows:
|
||||
|
||||
```shell
|
||||
$ cat >hostfile.json
|
||||
[
|
||||
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
|
||||
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
|
||||
]
|
||||
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
|
||||
```
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.data.datasets import load_cifar10
|
||||
|
||||
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
|
||||
x = x.astype("float32") / 255.0
|
||||
return (x - mean) / std
|
||||
|
||||
group = mx.distributed.init()
|
||||
|
||||
tr_iter = (
|
||||
tr.shuffle()
|
||||
.partition_if(group.size() > 1, group.size(), group.rank())
|
||||
.to_stream()
|
||||
.image_random_h_flip("image", prob=0.5)
|
||||
.pad("image", 0, 4, 4, 0.0)
|
||||
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
|
||||
test_iter = (
|
||||
test.to_stream()
|
||||
.partition_if(group.size() > 1, group.size(), group.rank())
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
return tr_iter, test_iter
|
||||
|
||||
@@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||
parser.add_argument("--cpu", action="store_true", help="use cpu only")
|
||||
|
||||
|
||||
def print_zero(group, *args, **kwargs):
|
||||
if group.rank() != 0:
|
||||
return
|
||||
flush = kwargs.pop("flush", True)
|
||||
print(*args, **kwargs, flush=flush)
|
||||
|
||||
|
||||
def eval_fn(model, inp, tgt):
|
||||
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
|
||||
|
||||
@@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
|
||||
return loss, acc
|
||||
|
||||
losses = []
|
||||
accs = []
|
||||
samples_per_sec = []
|
||||
world = mx.distributed.init()
|
||||
losses = 0
|
||||
accuracies = 0
|
||||
samples_per_sec = 0
|
||||
count = 0
|
||||
|
||||
def average_stats(stats, count):
|
||||
if world.size() == 1:
|
||||
return [s / count for s in stats]
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
stats = mx.distributed.all_sum(mx.array(stats))
|
||||
count = mx.distributed.all_sum(count)
|
||||
return (stats / count).tolist()
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
def step(inp, tgt):
|
||||
train_step_fn = nn.value_and_grad(model, train_step)
|
||||
(loss, acc), grads = train_step_fn(model, inp, tgt)
|
||||
grads = nn.utils.average_gradients(grads)
|
||||
optimizer.update(model, grads)
|
||||
return loss, acc
|
||||
|
||||
@@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
y = mx.array(batch["label"])
|
||||
tic = time.perf_counter()
|
||||
loss, acc = step(x, y)
|
||||
mx.eval(state)
|
||||
mx.eval(loss, acc, state)
|
||||
toc = time.perf_counter()
|
||||
loss = loss.item()
|
||||
acc = acc.item()
|
||||
losses.append(loss)
|
||||
accs.append(acc)
|
||||
throughput = x.shape[0] / (toc - tic)
|
||||
samples_per_sec.append(throughput)
|
||||
losses += loss.item()
|
||||
accuracies += acc.item()
|
||||
samples_per_sec += x.shape[0] / (toc - tic)
|
||||
count += 1
|
||||
if batch_counter % 10 == 0:
|
||||
print(
|
||||
l, a, s = average_stats(
|
||||
[losses, accuracies, world.size() * samples_per_sec],
|
||||
count,
|
||||
)
|
||||
print_zero(
|
||||
world,
|
||||
" | ".join(
|
||||
(
|
||||
f"Epoch {epoch:02d} [{batch_counter:03d}]",
|
||||
f"Train loss {loss:.3f}",
|
||||
f"Train acc {acc:.3f}",
|
||||
f"Throughput: {throughput:.2f} images/second",
|
||||
f"Train loss {l:.3f}",
|
||||
f"Train acc {a:.3f}",
|
||||
f"Throughput: {s:.2f} images/second",
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
mean_tr_loss = mx.mean(mx.array(losses))
|
||||
mean_tr_acc = mx.mean(mx.array(accs))
|
||||
samples_per_sec = mx.mean(mx.array(samples_per_sec))
|
||||
return mean_tr_loss, mean_tr_acc, samples_per_sec
|
||||
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
|
||||
|
||||
|
||||
def test_epoch(model, test_iter, epoch):
|
||||
accs = []
|
||||
accuracies = 0
|
||||
count = 0
|
||||
for batch_counter, batch in enumerate(test_iter):
|
||||
x = mx.array(batch["image"])
|
||||
y = mx.array(batch["label"])
|
||||
acc = eval_fn(model, x, y)
|
||||
acc_value = acc.item()
|
||||
accs.append(acc_value)
|
||||
mean_acc = mx.mean(mx.array(accs))
|
||||
return mean_acc
|
||||
accuracies += acc.item()
|
||||
count += 1
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
accuracies = mx.distributed.all_sum(accuracies)
|
||||
count = mx.distributed.all_sum(count)
|
||||
return (accuracies / count).item()
|
||||
|
||||
|
||||
def main(args):
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
# Initialize the distributed group and report the nodes that showed up
|
||||
world = mx.distributed.init()
|
||||
if world.size() > 1:
|
||||
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
|
||||
|
||||
model = getattr(resnet, args.arch)()
|
||||
|
||||
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
|
||||
print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
|
||||
|
||||
optimizer = optim.Adam(learning_rate=args.lr)
|
||||
|
||||
train_data, test_data = get_cifar10(args.batch_size)
|
||||
for epoch in range(args.epochs):
|
||||
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
|
||||
print(
|
||||
print_zero(
|
||||
world,
|
||||
" | ".join(
|
||||
(
|
||||
f"Epoch: {epoch}",
|
||||
f"avg. Train loss {tr_loss.item():.3f}",
|
||||
f"avg. Train acc {tr_acc.item():.3f}",
|
||||
f"Throughput: {throughput.item():.2f} images/sec",
|
||||
f"avg. Train loss {tr_loss:.3f}",
|
||||
f"avg. Train acc {tr_acc:.3f}",
|
||||
f"Throughput: {throughput:.2f} images/sec",
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
test_acc = test_epoch(model, test_data, epoch)
|
||||
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
|
||||
print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
|
||||
|
||||
train_data.reset()
|
||||
test_data.reset()
|
||||
|
||||
@@ -121,7 +121,7 @@ if __name__ == "__main__":
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("[INFO] Loading")
|
||||
torch_weights = torch.load(torch_path / "pytorch_model.bin")
|
||||
torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
|
||||
print("[INFO] Converting")
|
||||
mlx_weights = {
|
||||
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
|
||||
|
||||
156
esm/README.md
Normal file
156
esm/README.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# ESM-2
|
||||
|
||||
This repository provides an implementation of Meta's ESM-2 protein language model
|
||||
in MLX.[^1] ESM-2 is Meta’s second-generation Evolutionary Scale Model, a
|
||||
transformer-based protein language model trained on millions of diverse protein
|
||||
sequences with a masked language modeling objective.
|
||||
|
||||

|
||||
|
||||
_Example contact prediction map for a universal stress protein. In this case, ESM-2 650M achieves 86.4% precision at long-range contacts._
|
||||
|
||||
## Setup
|
||||
|
||||
Install the requirements:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Below are the available ESM-2 models:
|
||||
| Model | Parameters | Layers |
|
||||
|-------|------------|--------|
|
||||
| [`esm2_t6_8M_UR50D`](https://huggingface.co/facebook/esm2_t6_8M_UR50D) | 8M | 6 |
|
||||
| [`esm2_t12_35M_UR50D`](https://huggingface.co/facebook/esm2_t12_35M_UR50D) | 35M | 12 |
|
||||
| [`esm2_t30_150M_UR50D`](https://huggingface.co/facebook/esm2_t30_150M_UR50D) | 150M | 30 |
|
||||
| [`esm2_t33_650M_UR50D`](https://huggingface.co/facebook/esm2_t33_650M_UR50D) | 650M | 33 |
|
||||
| [`esm2_t36_3B_UR50D`](https://huggingface.co/facebook/esm2_t36_3B_UR50D) | 3B | 36 |
|
||||
| [`esm2_t48_15B_UR50D`](https://huggingface.co/facebook/esm2_t48_15B_UR50D) | 15B | 48 |
|
||||
|
||||
Convert a model to MLX format:
|
||||
|
||||
```bash
|
||||
python convert.py --hf-path facebook/esm2_t33_650M_UR50D
|
||||
```
|
||||
|
||||
This will save the converted model in a checkpoints directory.
|
||||
|
||||
### Basic Inference
|
||||
|
||||
```python
|
||||
from esm import ESM2
|
||||
|
||||
# Load model and tokenizer
|
||||
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
|
||||
|
||||
# Example protein sequence (human insulin)
|
||||
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
|
||||
|
||||
# Tokenize and run inference
|
||||
tokens = tokenizer.encode(sequence)
|
||||
result = model(tokens)
|
||||
logits = result["logits"] # Shape: (batch, length, vocab_size)
|
||||
```
|
||||
|
||||
### Masked Language Modeling
|
||||
|
||||
```bash
|
||||
# For a complete example, see main.py
|
||||
python main.py --sequence "YOUR_SEQUENCE" --mask-position 50
|
||||
```
|
||||
|
||||
### Embeddings
|
||||
|
||||
```python
|
||||
# Get sequence-level representations
|
||||
seq_repr = model.get_sequence_representations(tokens, layer=-1) # Shape: (batch, embed_dim)
|
||||
|
||||
# Extract per-residue representations from specific layers
|
||||
representations = model.extract_features(tokens, repr_layers=[20, 30, 33])
|
||||
final_layer = representations[33] # Shape: (batch, length, embed_dim)
|
||||
```
|
||||
|
||||
### Contact Prediction
|
||||
|
||||
```python
|
||||
# Predict residue-residue contacts
|
||||
contacts = model.predict_contacts(tokens) # Shape: (batch, length, length)
|
||||
|
||||
# Or compute contacts together with logits, representations, etc.
|
||||
outputs = model(tokens, return_contacts=True)
|
||||
contacts = outputs["contacts"]
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
**Mutation Effect Prediction**: [notebooks/mutation_effect_prediction.ipynb](notebooks/mutation_effect_prediction.ipynb)
|
||||
|
||||
This notebook demonstrates how to use ESM-2 for zero-shot mutation effect prediction by scoring amino acid substitutions based on their likelihood under the model. We validate the approach using experimental fitness data from β-lactamase TEM, showing how ESM-2 captures functional constraints without requiring structural information.
|
||||
|
||||
**Embeddings**: [notebooks/embeddings.ipynb](notebooks/embeddings.ipynb)
|
||||
|
||||
This notebook explores how ESM-2 generates meaningful protein embeddings that capture evolutionary and functional relationships between proteins. We analyze six diverse human proteins to demonstrate how the learned representations cluster proteins by function and reveal biological similarities.
|
||||
|
||||
**Contact Prediction**: [notebooks/contact_prediction.ipynb](notebooks/contact_prediction.ipynb)
|
||||
|
||||
This notebook shows how to predict residue-residue contacts in protein structures using ESM-2's attention patterns. We evaluate contact prediction performance on three diverse proteins, demonstrating how the model captures both local and long-range structural relationships directly from sequence data.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
Benchmark MLX performance:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_mx.py
|
||||
```
|
||||
|
||||
Benchmark PyTorch MPS performance:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_pt.py
|
||||
```
|
||||
|
||||
Expected performance on M4 MacBook Pro (ESM-2 650M, batch_size = 5):
|
||||
|
||||
- MLX: 299 ms per step, 16.71 sequences/sec
|
||||
- PyTorch MPS: 402 ms per step, 12.43 sequences/sec
|
||||
|
||||
### Testing
|
||||
|
||||
Verify correctness against original implementation:
|
||||
|
||||
```bash
|
||||
python test.py
|
||||
```
|
||||
|
||||
This tests tokenizer and model outputs (logits, hidden states, and attentions) for equivalence with the original implementation.
|
||||
|
||||
### Citations:
|
||||
|
||||
```bibtex
|
||||
@article{rives2019biological,
|
||||
author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Liu, Jason and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
|
||||
title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
|
||||
year={2019},
|
||||
doi={10.1101/622803},
|
||||
url={https://www.biorxiv.org/content/10.1101/622803v4},
|
||||
journal={PNAS}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Lin2023,
|
||||
author={Zeming Lin et al.},
|
||||
title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
|
||||
journal={Science},
|
||||
volume={379},
|
||||
pages={1123--1130},
|
||||
year={2023},
|
||||
doi={10.1126/science.ade2574},
|
||||
url={https://doi.org/10.1126/science.ade2574}
|
||||
}
|
||||
```
|
||||
|
||||
[^1]: Refer to the [paper](https://www.science.org/doi/10.1126/science.ade2574) and [code](https://github.com/facebookresearch/esm) for more details.
|
||||
BIN
esm/assets/contact_prediction.png
Normal file
BIN
esm/assets/contact_prediction.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 34 KiB |
47
esm/benchmarks/benchmark_mx.py
Normal file
47
esm/benchmarks/benchmark_mx.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Add parent directory to Python path
|
||||
cur_path = Path(__file__).parents[1].resolve()
|
||||
sys.path.append(str(cur_path))
|
||||
|
||||
from esm import ESM2
|
||||
|
||||
# Example protein sequence (Green Fluorescent Protein)
|
||||
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
|
||||
|
||||
# Load pretrained ESM-2 model and its tokenizer from local checkpoint
|
||||
tokenizer, model = ESM2.from_pretrained("checkpoints/mlx-esm2_t33_650M_UR50D")
|
||||
|
||||
# Number of sequences to process in each forward pass
|
||||
batch_size = 5
|
||||
|
||||
# Number of timing iterations for performance measurement
|
||||
steps = 50
|
||||
|
||||
# Tokenize the protein sequence into integer IDs for the model
|
||||
# Replicate the same sequence 'batch_size' times to create a batch
|
||||
tokens = tokenizer.batch_encode([protein_sequence] * batch_size)
|
||||
|
||||
# Warm-up phase
|
||||
for _ in range(10):
|
||||
result = model(tokens)
|
||||
mx.eval(result["logits"]) # Force computation to complete
|
||||
|
||||
# Measure average inference time over 'steps' iterations
|
||||
tic = time.time()
|
||||
for _ in range(steps):
|
||||
result = model(tokens)
|
||||
mx.eval(result["logits"]) # Synchronize and ensure computation finishes
|
||||
toc = time.time()
|
||||
|
||||
# Compute metrics: average time per step (ms) and throughput (sequences/sec)
|
||||
ms_per_step = 1000 * (toc - tic) / steps
|
||||
throughput = batch_size * 1000 / ms_per_step
|
||||
|
||||
# Display results
|
||||
print(f"Time (ms) per step: {ms_per_step:.3f}")
|
||||
print(f"Throughput: {throughput:.2f} sequences/sec")
|
||||
52
esm/benchmarks/benchmark_pt.py
Normal file
52
esm/benchmarks/benchmark_pt.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, EsmForMaskedLM
|
||||
|
||||
# Example protein sequence (Green Fluorescent Protein)
|
||||
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
|
||||
|
||||
# Hugging Face model identifier for ESM-2 (33 layers, 650M params, UR50D training set)
|
||||
model_name = "facebook/esm2_t33_650M_UR50D"
|
||||
|
||||
# Load tokenizer and model; move model to Apple Metal Performance Shaders (MPS) device
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = EsmForMaskedLM.from_pretrained(model_name).to("mps")
|
||||
|
||||
# Number of sequences per forward pass
|
||||
batch_size = 5
|
||||
|
||||
# Number of timing iterations
|
||||
steps = 50
|
||||
|
||||
# Tokenize input sequence and replicate for the batch
|
||||
# Replicate the same sequence 'batch_size' times to create a batch
|
||||
inputs = tokenizer(
|
||||
[protein_sequence] * batch_size,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024,
|
||||
)
|
||||
input_ids = inputs["input_ids"].to("mps")
|
||||
attention_mask = inputs["attention_mask"].to("mps")
|
||||
|
||||
# Warm-up phase
|
||||
for _ in range(10):
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
torch.mps.synchronize() # Ensure all queued ops on MPS are complete before next step
|
||||
|
||||
# Timed inference loop
|
||||
tic = time.time()
|
||||
for _ in range(steps):
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
torch.mps.synchronize() # Wait for computation to finish before timing next iteration
|
||||
toc = time.time()
|
||||
|
||||
# Compute performance metrics
|
||||
ms_per_step = 1000 * (toc - tic) / steps
|
||||
throughput = batch_size * 1000 / ms_per_step
|
||||
|
||||
# Report results
|
||||
print(f"Time (ms) per step: {ms_per_step:.3f}")
|
||||
print(f"Throughput: {throughput:.2f} sequences/sec")
|
||||
177
esm/convert.py
Normal file
177
esm/convert.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def download(hf_repo: str) -> Path:
|
||||
"""Download model from Hugging Face."""
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=hf_repo,
|
||||
allow_patterns=["*.safetensors", "*.json", "*.bin", "*.txt"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remap_key(key: str) -> str:
|
||||
"""Remap HuggingFace ESM key names to MLX format."""
|
||||
|
||||
# Skip position embeddings and position_ids
|
||||
if "position_embeddings" in key or "position_ids" in key:
|
||||
return None
|
||||
|
||||
# Map lm_head components properly
|
||||
if key == "lm_head.decoder.weight":
|
||||
return "lm_head.weight"
|
||||
if key == "lm_head.decoder.bias":
|
||||
return "lm_head.bias"
|
||||
if key == "lm_head.dense.weight":
|
||||
return "lm_head.dense.weight"
|
||||
if key == "lm_head.dense.bias":
|
||||
return "lm_head.dense.bias"
|
||||
if key == "lm_head.layer_norm.weight":
|
||||
return "lm_head.layer_norm.weight"
|
||||
if key == "lm_head.layer_norm.bias":
|
||||
return "lm_head.layer_norm.bias"
|
||||
|
||||
# Core remapping patterns
|
||||
key = key.replace("esm.embeddings.word_embeddings", "embed_tokens")
|
||||
key = key.replace("esm.encoder.emb_layer_norm_after", "emb_layer_norm_after")
|
||||
key = key.replace("esm.encoder.layer.", "layer_")
|
||||
key = key.replace("esm.contact_head", "contact_head")
|
||||
key = key.replace("lm_head", "lm_head")
|
||||
|
||||
# Attention patterns
|
||||
key = key.replace(".attention.self.", ".self_attn.")
|
||||
key = key.replace(".attention.output.dense", ".self_attn.out_proj")
|
||||
key = key.replace(".attention.LayerNorm", ".self_attn_layer_norm")
|
||||
key = key.replace(".query", ".q_proj")
|
||||
key = key.replace(".key", ".k_proj")
|
||||
key = key.replace(".value", ".v_proj")
|
||||
key = key.replace(".rotary_embeddings", ".rot_emb")
|
||||
|
||||
# FFN patterns
|
||||
key = key.replace(".intermediate.dense", ".fc1")
|
||||
key = key.replace(".output.dense", ".fc2")
|
||||
key = key.replace(".LayerNorm", ".final_layer_norm")
|
||||
|
||||
return key
|
||||
|
||||
|
||||
def load_weights(model_path: Path) -> Dict:
|
||||
"""Load weights from safetensors or PyTorch bin files."""
|
||||
|
||||
# Check for safetensors file
|
||||
safetensors_path = model_path / "model.safetensors"
|
||||
if safetensors_path.exists():
|
||||
print("Loading from safetensors...")
|
||||
return mx.load(str(safetensors_path))
|
||||
|
||||
# Check for single bin file
|
||||
single_bin_path = model_path / "pytorch_model.bin"
|
||||
if single_bin_path.exists():
|
||||
print("Loading from pytorch_model.bin...")
|
||||
state_dict = torch.load(str(single_bin_path), map_location="cpu")
|
||||
return {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
# Check for sharded bin files
|
||||
index_file = model_path / "pytorch_model.bin.index.json"
|
||||
if index_file.exists():
|
||||
print("Loading from sharded bin files...")
|
||||
with open(index_file) as f:
|
||||
index = json.load(f)
|
||||
|
||||
# Get unique shard files
|
||||
shard_files = set(index["weight_map"].values())
|
||||
|
||||
# Load all shards
|
||||
state_dict = {}
|
||||
for shard_file in sorted(shard_files):
|
||||
print(f" Loading shard: {shard_file}")
|
||||
shard_path = model_path / shard_file
|
||||
shard_dict = torch.load(str(shard_path), map_location="cpu")
|
||||
state_dict.update(shard_dict)
|
||||
|
||||
return {k: v.numpy() for k, v in state_dict.items()}
|
||||
|
||||
raise ValueError(f"No model weights found in {model_path}")
|
||||
|
||||
|
||||
def convert(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Convert ESM weights to MLX format."""
|
||||
|
||||
# Load weights from any format
|
||||
weights = load_weights(model_path)
|
||||
|
||||
# Convert keys and create MLX arrays
|
||||
mlx_weights = {}
|
||||
for key, value in weights.items():
|
||||
mlx_key = remap_key(key)
|
||||
if mlx_key is not None:
|
||||
mlx_weights[mlx_key] = (
|
||||
mx.array(value) if not isinstance(value, mx.array) else value
|
||||
)
|
||||
|
||||
# If lm_head.weight is missing but embed_tokens.weight exists, set up weight sharing
|
||||
# (This is for smaller models that don't have a separate lm_head.decoder.weight)
|
||||
if "lm_head.weight" not in mlx_weights and "embed_tokens.weight" in mlx_weights:
|
||||
mlx_weights["lm_head.weight"] = mlx_weights["embed_tokens.weight"]
|
||||
|
||||
return mlx_weights
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert ESM weights to MLX format")
|
||||
parser.add_argument(
|
||||
"--hf-path", default="facebook/esm2_t6_8M_UR50D", help="Hugging Face model path"
|
||||
)
|
||||
parser.add_argument("--mlx-path", default=None, help="Output path for MLX model")
|
||||
parser.add_argument(
|
||||
"--checkpoints-dir",
|
||||
default="checkpoints",
|
||||
help="Directory to save checkpoints (default: checkpoints)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Download model
|
||||
print(f"Downloading {args.hf_path}...")
|
||||
model_path = download(args.hf_path)
|
||||
|
||||
# Set output path
|
||||
if args.mlx_path is None:
|
||||
model_name = args.hf_path.split("/")[-1]
|
||||
checkpoints_dir = Path(args.checkpoints_dir)
|
||||
checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.mlx_path = checkpoints_dir / f"mlx-{model_name}"
|
||||
mlx_path = Path(args.mlx_path)
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert weights
|
||||
print("Converting weights...")
|
||||
mlx_weights = convert(model_path)
|
||||
|
||||
# Save weights
|
||||
print(f"Saving MLX weights to {mlx_path}...")
|
||||
mx.save_safetensors(str(mlx_path / "model.safetensors"), mlx_weights)
|
||||
|
||||
# Copy config and other files
|
||||
print("Copying config...")
|
||||
shutil.copy(model_path / "config.json", mlx_path / "config.json")
|
||||
|
||||
for file_name in ["special_tokens_map.json", "tokenizer.json", "vocab.txt"]:
|
||||
src_file = model_path / file_name
|
||||
if src_file.exists():
|
||||
shutil.copy(src_file, mlx_path / file_name)
|
||||
|
||||
print(f"Conversion complete! MLX model saved to {mlx_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19
esm/esm/__init__.py
Normal file
19
esm/esm/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
ESM-2 protein language model implementation in MLX
|
||||
"""
|
||||
|
||||
from .attention import MultiheadAttention
|
||||
from .model import ESM2
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .rotary_embedding import RotaryEmbedding
|
||||
from .tokenizer import ProteinTokenizer
|
||||
|
||||
__all__ = [
|
||||
"ESM2",
|
||||
"ProteinTokenizer",
|
||||
"ContactPredictionHead",
|
||||
"RobertaLMHead",
|
||||
"TransformerLayer",
|
||||
"MultiheadAttention",
|
||||
"RotaryEmbedding",
|
||||
]
|
||||
153
esm/esm/attention.py
Normal file
153
esm/esm/attention.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""
|
||||
Multi-head attention layer with rotary position embeddings, as used in ESM-2.
|
||||
|
||||
This module implements both self-attention (when `key` and `value` are not
|
||||
provided) and cross-attention. It projects input sequences into queries,
|
||||
keys, and values, applies rotary position embeddings to encode relative
|
||||
position information, computes scaled dot-product attention over multiple
|
||||
heads in parallel, and returns a combined output projection.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Total embedding dimension of the model input and output.
|
||||
num_heads (int): Number of parallel attention heads. Must divide `embed_dim`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
# Linear projections for queries, keys, and values (with bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# Linear projection for output (with bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# ESM-2 uses rotary embeddings
|
||||
self.rot_emb = RotaryEmbedding(dim=self.head_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
query,
|
||||
key: Optional[mx.array] = None,
|
||||
value: Optional[mx.array] = None,
|
||||
key_padding_mask: Optional[mx.array] = None,
|
||||
attn_mask: Optional[mx.array] = None,
|
||||
need_head_weights: bool = False,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
"""
|
||||
Multi-head attention forward pass.
|
||||
|
||||
Args:
|
||||
query: Tensor of shape (tgt_len, batch, embed_dim).
|
||||
key: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
|
||||
value: Optional tensor of shape (src_len, batch, embed_dim). Defaults to `query`.
|
||||
key_padding_mask: Optional mask of shape (batch, src_len) to ignore padded positions.
|
||||
attn_mask: Optional mask for attention (e.g., causal mask).
|
||||
need_head_weights: If True, return attention weights for each head separately.
|
||||
|
||||
Returns:
|
||||
attn_output: Tensor of shape (tgt_len, batch, embed_dim).
|
||||
attn_weights_out: Attention weights of shape
|
||||
(num_heads, batch, tgt_len, src_len) if per-head,
|
||||
or (batch, tgt_len, src_len) if averaged.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.shape
|
||||
assert embed_dim == self.embed_dim
|
||||
|
||||
# For self-attention, use query as key and value if not provided
|
||||
if key is None:
|
||||
key = query
|
||||
if value is None:
|
||||
value = query
|
||||
|
||||
# Project queries, keys, values
|
||||
q = self.q_proj(query)
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(value)
|
||||
|
||||
q = q * self.scaling
|
||||
|
||||
# Reshape for multi-head attention
|
||||
q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
|
||||
k = k.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
|
||||
v = v.reshape(-1, bsz * self.num_heads, self.head_dim).swapaxes(0, 1)
|
||||
|
||||
src_len = k.shape[1]
|
||||
|
||||
# Apply rotary embeddings if present
|
||||
if self.rot_emb:
|
||||
q, k = self.rot_emb(q, k)
|
||||
|
||||
# Compute attention weights
|
||||
attn_weights = q @ k.swapaxes(-2, -1)
|
||||
|
||||
assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
# Apply attention mask
|
||||
if attn_mask is not None:
|
||||
attn_mask = mx.expand_dims(attn_mask, 0)
|
||||
attn_weights = attn_weights + attn_mask
|
||||
|
||||
# Apply key padding mask
|
||||
if key_padding_mask is not None:
|
||||
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len)
|
||||
# Convert key_padding_mask to boolean and expand dimensions
|
||||
# key_padding_mask: [bsz, src_len] -> [bsz, 1, 1, src_len]
|
||||
mask = mx.expand_dims(
|
||||
mx.expand_dims(key_padding_mask.astype(mx.bool_), 1), 2
|
||||
)
|
||||
# Apply mask: set attention to -inf where mask is True (padded positions)
|
||||
attn_weights = mx.where(mask, -mx.inf, attn_weights)
|
||||
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
# Apply softmax
|
||||
attn_weights_float = mx.softmax(attn_weights.astype(mx.float32), axis=-1)
|
||||
attn_probs = attn_weights_float
|
||||
|
||||
# Compute attention output
|
||||
attn = attn_probs @ v
|
||||
assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
|
||||
# Reshape output
|
||||
attn = attn.swapaxes(0, 1).reshape(tgt_len, bsz, embed_dim)
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
# Return attention weights if requested
|
||||
attn_weights_out: Optional[mx.array] = None
|
||||
if need_head_weights:
|
||||
# Return attention weights for each head separately
|
||||
attn_weights_out = (
|
||||
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len)
|
||||
.astype(attn.dtype)
|
||||
.swapaxes(0, 1)
|
||||
)
|
||||
else:
|
||||
# Return averaged attention weights
|
||||
attn_weights_out = mx.mean(
|
||||
attn_weights_float.reshape(bsz, self.num_heads, tgt_len, src_len),
|
||||
axis=1,
|
||||
).astype(attn.dtype)
|
||||
|
||||
return attn, attn_weights_out
|
||||
340
esm/esm/model.py
Normal file
340
esm/esm/model.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .modules import ContactPredictionHead, RobertaLMHead, TransformerLayer
|
||||
from .tokenizer import ProteinTokenizer
|
||||
|
||||
|
||||
class ESM2(nn.Module):
|
||||
"""
|
||||
ESM-2 protein language model in MLX.
|
||||
|
||||
Args:
|
||||
num_layers (int): Number of transformer layers.
|
||||
embed_dim (int): Embedding dimension.
|
||||
attention_heads (int): Number of attention heads.
|
||||
tokenizer (Optional[ProteinTokenizer]): Tokenizer to use (created if None).
|
||||
token_dropout (bool): Apply token-dropout masking behavior.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 33,
|
||||
embed_dim: int = 1280,
|
||||
attention_heads: int = 20,
|
||||
tokenizer: Optional[ProteinTokenizer] = None,
|
||||
token_dropout: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
self.embed_dim = embed_dim
|
||||
self.attention_heads = attention_heads
|
||||
|
||||
# Initialize tokenizer
|
||||
if tokenizer is None:
|
||||
tokenizer = ProteinTokenizer()
|
||||
self.tokenizer = tokenizer
|
||||
self.vocab_size = len(tokenizer)
|
||||
|
||||
# Special token IDs / config
|
||||
self.padding_idx = tokenizer.pad_id
|
||||
self.mask_idx = tokenizer.mask_id
|
||||
self.cls_idx = tokenizer.cls_id
|
||||
self.eos_idx = tokenizer.eos_id
|
||||
self.prepend_bos = tokenizer.prepend_bos
|
||||
self.append_eos = tokenizer.append_eos
|
||||
self.token_dropout = token_dropout
|
||||
|
||||
self._init_submodules()
|
||||
|
||||
def _init_submodules(self) -> None:
|
||||
"""Initialize embeddings, transformer stack, and output heads."""
|
||||
self.embed_scale = 1
|
||||
|
||||
# Token embeddings
|
||||
self.embed_tokens = nn.Embedding(self.vocab_size, self.embed_dim)
|
||||
|
||||
# Transformer layers (register each layer so MLX tracks parameters)
|
||||
self._layer_indices = list(range(self.num_layers))
|
||||
for i in self._layer_indices:
|
||||
layer = TransformerLayer(
|
||||
self.embed_dim,
|
||||
4 * self.embed_dim, # FFN dimension = 4×embed_dim
|
||||
self.attention_heads,
|
||||
)
|
||||
setattr(self, f"layer_{i}", layer)
|
||||
|
||||
# Contact prediction head (uses all layers × heads attentions)
|
||||
self.contact_head = ContactPredictionHead(
|
||||
self.num_layers * self.attention_heads,
|
||||
self.prepend_bos,
|
||||
self.append_eos,
|
||||
eos_idx=self.eos_idx,
|
||||
)
|
||||
|
||||
# Final norm + LM head (tied weights)
|
||||
self.emb_layer_norm_after = nn.LayerNorm(self.embed_dim)
|
||||
self.lm_head = RobertaLMHead(
|
||||
embed_dim=self.embed_dim,
|
||||
output_dim=self.vocab_size,
|
||||
weight=self.embed_tokens.weight,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
repr_layers: List[int] = [],
|
||||
need_head_weights: bool = False,
|
||||
return_contacts: bool = False,
|
||||
) -> Dict[str, mx.array]:
|
||||
"""
|
||||
Forward pass through ESM-2.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of token IDs with shape (B, T).
|
||||
repr_layers: Layers to return hidden states from (0..num_layers).
|
||||
need_head_weights: If True, return attention weights.
|
||||
return_contacts: If True, compute residue-residue contact probabilities.
|
||||
|
||||
Returns:
|
||||
dict with:
|
||||
logits: (B, T, V)
|
||||
representations: {layer_idx: (B, T, E)}
|
||||
attentions: (B, L, H, T, T) if requested
|
||||
contacts: (B, T', T') if requested
|
||||
"""
|
||||
if return_contacts:
|
||||
need_head_weights = True
|
||||
|
||||
# Ensure tokens is 2D (B, T)
|
||||
if tokens.ndim == 1:
|
||||
tokens = mx.expand_dims(tokens, axis=0)
|
||||
assert tokens.ndim == 2
|
||||
|
||||
# Padding mask (B, T)
|
||||
padding_mask = mx.equal(tokens, self.padding_idx)
|
||||
|
||||
# Embed tokens (B, T, E)
|
||||
x = self.embed_scale * self.embed_tokens(tokens)
|
||||
|
||||
# Token dropout: zero masked tokens + rescale based on observed mask ratio
|
||||
if self.token_dropout:
|
||||
# x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
|
||||
mask_positions = mx.equal(tokens, self.mask_idx)
|
||||
x = mx.where(mask_positions[:, :, None], 0.0, x)
|
||||
|
||||
# x: B x T x C
|
||||
mask_ratio_train = 0.15 * 0.8
|
||||
src_lengths = mx.sum(~padding_mask, axis=-1) # Shape: (B,)
|
||||
mask_ratio_observed = mx.sum(mask_positions, axis=-1).astype(x.dtype) / src_lengths # Shape: (B,)
|
||||
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
|
||||
|
||||
# Zero out padding positions
|
||||
if padding_mask.any():
|
||||
x = x * (1 - padding_mask[:, :, None].astype(x.dtype))
|
||||
|
||||
# Track requested representations
|
||||
repr_layers = set(repr_layers)
|
||||
hidden_representations: Dict[int, mx.array] = {}
|
||||
if 0 in repr_layers:
|
||||
hidden_representations[0] = x
|
||||
|
||||
if need_head_weights:
|
||||
attn_weights: List[mx.array] = []
|
||||
|
||||
# (B, T, E) -> (T, B, E) for transformer layers
|
||||
x = mx.swapaxes(x, 0, 1)
|
||||
|
||||
# If no padding anywhere, skip the mask
|
||||
if not padding_mask.any():
|
||||
padding_mask = None
|
||||
|
||||
# Transformer stack
|
||||
for layer_idx in self._layer_indices:
|
||||
layer = getattr(self, f"layer_{layer_idx}")
|
||||
x, attn = layer(
|
||||
x,
|
||||
self_attn_padding_mask=padding_mask,
|
||||
need_head_weights=need_head_weights,
|
||||
)
|
||||
|
||||
# Save hidden representation if requested (store back as (B, T, E))
|
||||
if (layer_idx + 1) in repr_layers:
|
||||
hidden_representations[layer_idx + 1] = mx.swapaxes(x, 0, 1)
|
||||
|
||||
# Save per-layer attentions if requested (H, B, T, T) -> (B, H, T, T)
|
||||
if need_head_weights:
|
||||
attn_weights.append(mx.swapaxes(attn, 0, 1))
|
||||
|
||||
# Final layer norm, back to (B, T, E)
|
||||
x = self.emb_layer_norm_after(x)
|
||||
x = mx.swapaxes(x, 0, 1)
|
||||
|
||||
# Save final hidden if requested
|
||||
if (layer_idx + 1) in repr_layers:
|
||||
hidden_representations[layer_idx + 1] = x
|
||||
|
||||
# Language modeling logits
|
||||
x = self.lm_head(x)
|
||||
|
||||
# Build result dict
|
||||
result: Dict[str, mx.array] = {
|
||||
"logits": x,
|
||||
"representations": hidden_representations,
|
||||
}
|
||||
|
||||
# Collect attentions and optional contacts
|
||||
if need_head_weights:
|
||||
# Stack layers -> (B, L, H, T, T)
|
||||
attentions = mx.stack(attn_weights, axis=1)
|
||||
|
||||
# Mask out padded positions if present
|
||||
if padding_mask is not None:
|
||||
attention_mask = 1 - padding_mask.astype(attentions.dtype)
|
||||
attention_mask = mx.expand_dims(attention_mask, 1) * mx.expand_dims(
|
||||
attention_mask, 2
|
||||
)
|
||||
attentions = attentions * attention_mask[:, None, None, :, :]
|
||||
|
||||
result["attentions"] = attentions
|
||||
|
||||
# Compute contacts if requested
|
||||
if return_contacts:
|
||||
contacts = self.contact_head(tokens, attentions)
|
||||
result["contacts"] = contacts
|
||||
|
||||
return result
|
||||
|
||||
def predict_contacts(self, tokens: mx.array) -> mx.array:
|
||||
"""
|
||||
Predict residue-residue contacts.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
mx.array: Contact probabilities of shape (B, T', T').
|
||||
"""
|
||||
return self(tokens, return_contacts=True)["contacts"]
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
repr_layers: Optional[List[int]] = None,
|
||||
return_all_hiddens: bool = False,
|
||||
) -> Dict[int, mx.array]:
|
||||
"""
|
||||
Extract hidden representations from selected layers.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
repr_layers: Layer indices to return (default: last layer).
|
||||
return_all_hiddens: If True, return all layers (0..num_layers).
|
||||
|
||||
Returns:
|
||||
dict: {layer_idx: (B, T, E)} for requested layers.
|
||||
"""
|
||||
if return_all_hiddens:
|
||||
repr_layers = list(range(self.num_layers + 1))
|
||||
elif repr_layers is None:
|
||||
repr_layers = [self.num_layers]
|
||||
|
||||
result = self(tokens, repr_layers=repr_layers)
|
||||
return result["representations"]
|
||||
|
||||
def get_sequence_representations(
|
||||
self,
|
||||
tokens: mx.array,
|
||||
layer: int = -1,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Average token representations into a per-sequence embedding.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
layer: Layer index to use (-1 or num_layers for last).
|
||||
|
||||
Returns:
|
||||
mx.array: Sequence embeddings of shape (B, E).
|
||||
"""
|
||||
if layer == -1:
|
||||
layer = self.num_layers
|
||||
|
||||
representations = self.extract_features(tokens, repr_layers=[layer])
|
||||
repr = representations[layer]
|
||||
|
||||
# Mask: non-padding and not CLS; optionally not EOS
|
||||
mask = mx.logical_and(
|
||||
mx.not_equal(tokens, self.padding_idx),
|
||||
mx.not_equal(tokens, self.cls_idx),
|
||||
)
|
||||
if self.append_eos:
|
||||
mask = mx.logical_and(mask, mx.not_equal(tokens, self.eos_idx))
|
||||
|
||||
# Mean over valid positions
|
||||
mask = mask[:, :, None].astype(repr.dtype)
|
||||
masked_repr = repr * mask
|
||||
seq_lens = mx.sum(mask, axis=1, keepdims=True)
|
||||
seq_repr = mx.sum(masked_repr, axis=1) / mx.maximum(seq_lens[:, :, 0], 1.0)
|
||||
|
||||
return seq_repr
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str) -> Tuple[ProteinTokenizer, "ESM2"]:
|
||||
"""
|
||||
Load model weights and config from a directory.
|
||||
|
||||
Expects:
|
||||
- config.json
|
||||
- model.safetensors
|
||||
- vocab.txt (optional, will use default if not found)
|
||||
- special_tokens_map.json (optional, will use default if not found)
|
||||
|
||||
Args:
|
||||
model_path: Path to directory with weights and config.
|
||||
|
||||
Returns:
|
||||
(tokenizer, model): Initialized tokenizer and ESM2 model.
|
||||
"""
|
||||
model_dir = Path(model_path)
|
||||
config_path = model_dir / "config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Check for vocab and special tokens files
|
||||
vocab_path = model_dir / "vocab.txt"
|
||||
special_tokens_path = model_dir / "special_tokens_map.json"
|
||||
|
||||
if vocab_path.exists() and special_tokens_path.exists():
|
||||
tokenizer = ProteinTokenizer(
|
||||
vocab_file=str(vocab_path),
|
||||
special_tokens_map_file=str(special_tokens_path),
|
||||
)
|
||||
else:
|
||||
tokenizer = ProteinTokenizer()
|
||||
|
||||
model = cls(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
embed_dim=config["hidden_size"],
|
||||
attention_heads=config["num_attention_heads"],
|
||||
tokenizer=tokenizer,
|
||||
token_dropout=config["token_dropout"],
|
||||
)
|
||||
|
||||
# Load safetensors as nested dict and update model params
|
||||
weights_path = model_dir / "model.safetensors"
|
||||
flat_weights = mx.load(str(weights_path))
|
||||
nested_weights: Dict[str, dict] = {}
|
||||
for key, value in flat_weights.items():
|
||||
parts = key.split(".")
|
||||
cur = nested_weights
|
||||
for p in parts[:-1]:
|
||||
cur = cur.setdefault(p, {})
|
||||
cur[parts[-1]] = value
|
||||
|
||||
model.update(nested_weights)
|
||||
return tokenizer, model
|
||||
212
esm/esm/modules.py
Normal file
212
esm/esm/modules.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import MultiheadAttention
|
||||
|
||||
|
||||
def symmetrize(x: mx.array) -> mx.array:
|
||||
"""
|
||||
Make a tensor symmetric over its last two dimensions.
|
||||
|
||||
Args:
|
||||
x: Tensor with shape (..., L, L).
|
||||
|
||||
Returns:
|
||||
mx.array: Symmetrized tensor of shape (..., L, L).
|
||||
"""
|
||||
# Add tensor to its transpose over the last two dims
|
||||
return x + mx.swapaxes(x, -1, -2)
|
||||
|
||||
|
||||
def apc(x: mx.array) -> mx.array:
|
||||
"""
|
||||
Apply Average Product Correction (APC) to remove background co-variation.
|
||||
|
||||
Args:
|
||||
x: Tensor with shape (..., L, L).
|
||||
|
||||
Returns:
|
||||
mx.array: APC-corrected tensor of shape (..., L, L).
|
||||
"""
|
||||
# Compute row, column, and total sums
|
||||
a1 = mx.sum(x, axis=-1, keepdims=True)
|
||||
a2 = mx.sum(x, axis=-2, keepdims=True)
|
||||
a12 = mx.sum(x, axis=(-1, -2), keepdims=True)
|
||||
|
||||
# Expected co-variation under independence
|
||||
expected = (a1 * a2) / a12
|
||||
return x - expected
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
|
||||
"""
|
||||
Transformer layer used in ESM-2.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Model embedding dimension.
|
||||
ffn_embed_dim (int): Hidden dimension of the feed-forward network.
|
||||
attention_heads (int): Number of attention heads.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_embed_dim: int,
|
||||
attention_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_embed_dim = ffn_embed_dim
|
||||
self.attention_heads = attention_heads
|
||||
self._init_submodules()
|
||||
|
||||
def _init_submodules(self) -> None:
|
||||
"""Initialize attention, norms, and feed-forward submodules."""
|
||||
self.self_attn = MultiheadAttention(self.embed_dim, self.attention_heads)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
|
||||
self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
self_attn_mask: Optional[mx.array] = None,
|
||||
self_attn_padding_mask: Optional[mx.array] = None,
|
||||
need_head_weights: bool = False,
|
||||
):
|
||||
"""
|
||||
Forward pass for the Transformer layer.
|
||||
|
||||
Args:
|
||||
x: Tensor of shape (seq_len, batch, embed_dim).
|
||||
self_attn_mask: Optional attention mask.
|
||||
self_attn_padding_mask: Optional padding mask of shape (batch, seq_len).
|
||||
need_head_weights: If True, return per-head attention weights.
|
||||
|
||||
Returns:
|
||||
x: Tensor of shape (seq_len, batch, embed_dim).
|
||||
attn: Attention weights of shape
|
||||
(num_heads, batch, tgt_len, src_len) if per-head,
|
||||
or (batch, tgt_len, src_len) if averaged.
|
||||
"""
|
||||
# Self-attention block
|
||||
residual = x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
attn_mask=self_attn_mask,
|
||||
need_head_weights=need_head_weights,
|
||||
)
|
||||
x = residual + x
|
||||
|
||||
# Feed-forward block
|
||||
residual = x
|
||||
x = self.final_layer_norm(x)
|
||||
x = nn.gelu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
x = residual + x
|
||||
|
||||
return x, attn
|
||||
|
||||
|
||||
class RobertaLMHead(nn.Module):
|
||||
"""
|
||||
Masked Language Modeling (MLM) head with tied weights.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Embedding dimension of the backbone.
|
||||
output_dim (int): Vocabulary size.
|
||||
weight (mx.array): Embedding weight matrix for tied projection.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim: int, output_dim: int, weight: mx.array):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(embed_dim, embed_dim)
|
||||
self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
self.weight = weight
|
||||
self.bias = mx.zeros(output_dim)
|
||||
|
||||
def __call__(self, features: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass for the MLM head.
|
||||
|
||||
Args:
|
||||
features: Tensor of shape (seq_len, batch, embed_dim).
|
||||
|
||||
Returns:
|
||||
mx.array: Logits of shape (seq_len, batch, output_dim).
|
||||
"""
|
||||
# Transform features before projection to vocab
|
||||
x = self.dense(features)
|
||||
x = nn.gelu(x)
|
||||
x = self.layer_norm(x)
|
||||
return mx.matmul(x, self.weight.T) + self.bias
|
||||
|
||||
|
||||
class ContactPredictionHead(nn.Module):
|
||||
"""
|
||||
Predict residue-residue contact probabilities from attention maps.
|
||||
|
||||
Args:
|
||||
in_features (int): Number of attention channels (layers × heads).
|
||||
prepend_bos (bool): If True, drop BOS/CLS token attentions.
|
||||
append_eos (bool): If True, drop EOS token attentions.
|
||||
bias (bool): Whether the regression layer uses a bias term.
|
||||
eos_idx (Optional[int]): Token ID for EOS; required if append_eos=True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
prepend_bos: bool,
|
||||
append_eos: bool,
|
||||
bias: bool = True,
|
||||
eos_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.prepend_bos = prepend_bos
|
||||
self.append_eos = append_eos
|
||||
if append_eos and eos_idx is None:
|
||||
raise ValueError("append_eos=True but eos_idx was not provided.")
|
||||
self.eos_idx = eos_idx
|
||||
self.regression = nn.Linear(in_features, 1, bias=bias)
|
||||
|
||||
def __call__(self, tokens: mx.array, attentions: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass for contact prediction.
|
||||
|
||||
Args:
|
||||
tokens: Tensor of shape (B, T).
|
||||
attentions: Tensor of shape (B, L, H, T, T).
|
||||
|
||||
Returns:
|
||||
mx.array: Contact probabilities of shape (B, T', T'),
|
||||
where T' = T - [prepend_bos] - [append_eos].
|
||||
"""
|
||||
# Remove EOS attentions if requested
|
||||
if self.append_eos:
|
||||
eos_mask = mx.not_equal(tokens, self.eos_idx).astype(attentions.dtype)
|
||||
eos_mask = eos_mask[:, None, :] * eos_mask[:, :, None]
|
||||
attentions = attentions * eos_mask[:, None, None, :, :]
|
||||
attentions = attentions[..., :-1, :-1]
|
||||
|
||||
# Remove BOS attentions if requested
|
||||
if self.prepend_bos:
|
||||
attentions = attentions[..., 1:, 1:]
|
||||
|
||||
# Merge (layers × heads) into channel dimension
|
||||
batch_size, layers, heads, seqlen, _ = attentions.shape
|
||||
attentions = attentions.reshape(batch_size, layers * heads, seqlen, seqlen)
|
||||
|
||||
# Symmetrize and apply APC to enhance contact signal
|
||||
attentions = apc(symmetrize(attentions))
|
||||
|
||||
# Apply logistic regression over channels
|
||||
attentions = mx.transpose(attentions, axes=[0, 2, 3, 1])
|
||||
logits = self.regression(attentions)
|
||||
return nn.sigmoid(mx.squeeze(logits, axis=3))
|
||||
114
esm/esm/rotary_embedding.py
Normal file
114
esm/esm/rotary_embedding.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def rotate_half(x: mx.array) -> mx.array:
|
||||
"""
|
||||
Rotate last dimension by splitting into two halves and swapping.
|
||||
|
||||
Args:
|
||||
x: Tensor with even-sized last dimension.
|
||||
|
||||
Returns:
|
||||
mx.array: Tensor of same shape as `x` with halves rotated.
|
||||
"""
|
||||
# Split into two equal halves along the last dimension
|
||||
x1, x2 = mx.split(x, 2, axis=-1)
|
||||
# Swap halves and negate the second half
|
||||
return mx.concatenate((-x2, x1), axis=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
|
||||
"""
|
||||
Apply rotary position embeddings to a tensor.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (..., seq_len, dim).
|
||||
cos: Cosine embedding table of shape (1, seq_len, dim).
|
||||
sin: Sine embedding table of shape (1, seq_len, dim).
|
||||
|
||||
Returns:
|
||||
mx.array: Tensor with rotary position embeddings applied.
|
||||
"""
|
||||
# Trim cos/sin to match the sequence length of x
|
||||
cos = cos[:, : x.shape[-2], :]
|
||||
sin = sin[:, : x.shape[-2], :]
|
||||
|
||||
# Elementwise rotation: (x * cos) + (rotate_half(x) * sin)
|
||||
return (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""
|
||||
Rotary position embedding (RoPE) module.
|
||||
|
||||
Args:
|
||||
dim (int): Head dimension size (must be even).
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, *_, **__):
|
||||
super().__init__()
|
||||
# Precompute inverse frequency for each pair of dimensions
|
||||
self.inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
|
||||
|
||||
# Cache for cosine/sine tables to avoid recomputation
|
||||
self._seq_len_cached = None
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
def _update_cos_sin_tables(
|
||||
self, x: mx.array, seq_dimension: int = 1
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Compute and cache cos/sin tables for the given sequence length.
|
||||
|
||||
Args:
|
||||
x: Reference tensor for sequence length.
|
||||
seq_dimension: Axis containing the sequence length.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
cos: Cosine table of shape (1, seq_len, dim).
|
||||
sin: Sine table of shape (1, seq_len, dim).
|
||||
"""
|
||||
seq_len = x.shape[seq_dimension]
|
||||
# Only update cache if sequence length has changed
|
||||
if seq_len != self._seq_len_cached:
|
||||
self._seq_len_cached = seq_len
|
||||
# Time steps: shape (seq_len,)
|
||||
t = mx.arange(seq_len).astype(self.inv_freq.dtype)
|
||||
# Outer product between time and inverse frequency
|
||||
freqs = mx.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Duplicate frequencies for cos/sin dimensions
|
||||
emb = mx.concatenate((freqs, freqs), axis=-1)
|
||||
|
||||
self._cos_cached = mx.cos(emb)[None, :, :]
|
||||
self._sin_cached = mx.sin(emb)[None, :, :]
|
||||
|
||||
return self._cos_cached, self._sin_cached
|
||||
|
||||
def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Apply rotary position embeddings to queries and keys.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape (..., seq_len, dim).
|
||||
k: Key tensor of shape (..., seq_len, dim).
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
q_rot: Query tensor with RoPE applied.
|
||||
k_rot: Key tensor with RoPE applied.
|
||||
"""
|
||||
# Get (and cache) cos/sin tables based on key sequence length
|
||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
|
||||
k, seq_dimension=-2
|
||||
)
|
||||
|
||||
# Apply rotary embeddings to both q and k
|
||||
return (
|
||||
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
|
||||
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
|
||||
)
|
||||
241
esm/esm/tokenizer.py
Normal file
241
esm/esm/tokenizer.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# Canonical amino-acid tokens (IUPAC standard + uncommon variants)
|
||||
PROTEIN_TOKENS = [
|
||||
"L",
|
||||
"A",
|
||||
"G",
|
||||
"V",
|
||||
"S",
|
||||
"E",
|
||||
"R",
|
||||
"T",
|
||||
"I",
|
||||
"D",
|
||||
"P",
|
||||
"K",
|
||||
"Q",
|
||||
"N",
|
||||
"F",
|
||||
"Y",
|
||||
"M",
|
||||
"H",
|
||||
"W",
|
||||
"C",
|
||||
"X",
|
||||
"B",
|
||||
"U",
|
||||
"Z",
|
||||
"O",
|
||||
".",
|
||||
"-",
|
||||
]
|
||||
|
||||
ArrayLike = Union[List[int], mx.array]
|
||||
|
||||
|
||||
class ProteinTokenizer:
|
||||
"""
|
||||
Protein sequence tokenizer compatible with ESM-2.
|
||||
|
||||
This class converts protein sequences into token IDs and back, following
|
||||
the vocabulary, special tokens, and formatting rules used by ESM-2.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file: Optional[str] = None,
|
||||
special_tokens_map_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the ProteinTokenizer.
|
||||
|
||||
Args:
|
||||
vocab_file: Optional path to a file containing the vocabulary,
|
||||
one token per line.
|
||||
special_tokens_map_file: Optional path to a JSON file defining
|
||||
special token names and values.
|
||||
|
||||
If both files are provided, they override the default vocabulary and
|
||||
special token mappings. Otherwise, defaults are loaded.
|
||||
"""
|
||||
|
||||
# Load vocabulary from files if given, otherwise use built-in defaults
|
||||
if vocab_file and special_tokens_map_file:
|
||||
self._load_from_files(vocab_file, special_tokens_map_file)
|
||||
else:
|
||||
self._load_default_vocab()
|
||||
|
||||
# Create token ↔ ID mappings
|
||||
self.token_to_id = {tok: i for i, tok in enumerate(self.vocab)}
|
||||
self.id_to_token = {i: tok for i, tok in enumerate(self.vocab)}
|
||||
|
||||
# Cache special token IDs
|
||||
self.cls_id = self.token_to_id["<cls>"]
|
||||
self.pad_id = self.token_to_id["<pad>"]
|
||||
self.eos_id = self.token_to_id["<eos>"]
|
||||
self.unk_id = self.token_to_id["<unk>"]
|
||||
self.mask_id = self.token_to_id["<mask>"]
|
||||
|
||||
# Behavior flags for ESM-2-style BOS/EOS
|
||||
self.prepend_bos = True
|
||||
self.append_eos = True
|
||||
|
||||
def _load_from_files(self, vocab_file: str, special_tokens_map_file: str) -> None:
|
||||
"""Load vocabulary and special tokens from the provided files."""
|
||||
# Vocabulary file: one token per line
|
||||
vocab_path = Path(vocab_file)
|
||||
with open(vocab_path, "r", encoding="utf-8") as f:
|
||||
self.vocab = [line.strip() for line in f if line.strip()]
|
||||
|
||||
# Special tokens mapping file (JSON)
|
||||
special_tokens_path = Path(special_tokens_map_file)
|
||||
with open(special_tokens_path, "r", encoding="utf-8") as f:
|
||||
self.special_tokens_map = json.load(f)
|
||||
|
||||
def _load_default_vocab(self) -> None:
|
||||
"""Load the built-in ESM vocabulary and special token mapping."""
|
||||
# ESM convention: prepend special tokens, then amino acids, then <mask>
|
||||
prepend_toks = ["<cls>", "<pad>", "<eos>", "<unk>"]
|
||||
append_toks = ["<mask>"]
|
||||
|
||||
self.vocab = prepend_toks + PROTEIN_TOKENS
|
||||
|
||||
# Pad vocab size to multiple of 8 (original implementation detail)
|
||||
while len(self.vocab) % 8 != 0:
|
||||
self.vocab.append(f"<null_{len(self.vocab) - len(prepend_toks)}>")
|
||||
|
||||
self.vocab.extend(append_toks)
|
||||
|
||||
# Default special tokens map
|
||||
self.special_tokens_map = {
|
||||
"cls_token": "<cls>",
|
||||
"pad_token": "<pad>",
|
||||
"eos_token": "<eos>",
|
||||
"unk_token": "<unk>",
|
||||
"mask_token": "<mask>",
|
||||
}
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sequence: str,
|
||||
*,
|
||||
add_special_tokens: bool = True,
|
||||
return_batch_dim: bool = False,
|
||||
dtype=mx.int32,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Convert a protein sequence into token IDs.
|
||||
|
||||
Args:
|
||||
sequence: Protein sequence (case-insensitive).
|
||||
add_special_tokens: If True, add <cls> at the start and <eos> at the end.
|
||||
return_batch_dim: If True, output shape will be (1, L) instead of (L,).
|
||||
dtype: MLX dtype for the returned array.
|
||||
|
||||
Returns:
|
||||
mx.array: Token IDs of shape (L,) or (1, L).
|
||||
"""
|
||||
ids: List[int] = []
|
||||
|
||||
if add_special_tokens and self.prepend_bos:
|
||||
ids.append(self.cls_id)
|
||||
|
||||
# Map each residue to its ID (defaulting to <unk> if not in vocab)
|
||||
for ch in sequence.upper():
|
||||
ids.append(self.token_to_id.get(ch, self.unk_id))
|
||||
|
||||
if add_special_tokens and self.append_eos:
|
||||
ids.append(self.eos_id)
|
||||
|
||||
arr = mx.array(ids, dtype=dtype)
|
||||
return mx.expand_dims(arr, axis=0) if return_batch_dim else arr
|
||||
|
||||
def batch_encode(
|
||||
self,
|
||||
sequences: Sequence[str],
|
||||
*,
|
||||
add_special_tokens: bool = True,
|
||||
max_length: Optional[int] = None,
|
||||
dtype=mx.int32,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Encode multiple protein sequences into a padded batch.
|
||||
|
||||
Args:
|
||||
sequences: List/sequence of protein strings.
|
||||
add_special_tokens: If True, add <cls> and <eos> tokens.
|
||||
max_length: If provided, truncate sequences to this length before padding.
|
||||
dtype: MLX dtype for the returned array.
|
||||
|
||||
Returns:
|
||||
mx.array: Tensor of shape (B, L) with right-padding using <pad> IDs.
|
||||
"""
|
||||
# Encode each sequence as (L,)
|
||||
encoded = [
|
||||
self.encode(s, add_special_tokens=add_special_tokens, dtype=dtype)
|
||||
for s in sequences
|
||||
]
|
||||
encoded = [e if e.ndim == 1 else e[0] for e in encoded]
|
||||
|
||||
if max_length is not None:
|
||||
encoded = [e[:max_length] for e in encoded]
|
||||
|
||||
# Find the longest sequence and right-pad all others
|
||||
max_len = max((int(e.shape[0]) for e in encoded), default=0)
|
||||
padded = []
|
||||
for e in encoded:
|
||||
pad_len = max_len - int(e.shape[0])
|
||||
if pad_len > 0:
|
||||
pad = mx.full((pad_len,), self.pad_id, dtype=dtype)
|
||||
e = mx.concatenate([e, pad], axis=0)
|
||||
padded.append(e)
|
||||
|
||||
return mx.stack(padded, axis=0) if padded else mx.array([], dtype=dtype)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids: ArrayLike,
|
||||
*,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Convert token IDs back into a protein sequence string.
|
||||
|
||||
Args:
|
||||
token_ids: 1-D or 2-D array/list of IDs. If 2-D, only the first row is decoded.
|
||||
skip_special_tokens: If True, remove all special tokens from output.
|
||||
|
||||
Returns:
|
||||
str: Protein sequence.
|
||||
"""
|
||||
# Normalize to a 1-D MLX array
|
||||
if hasattr(token_ids, "shape") and hasattr(token_ids, "tolist"):
|
||||
ids = token_ids if token_ids.ndim == 1 else token_ids[0]
|
||||
else:
|
||||
ids = mx.array(token_ids, dtype=mx.int32)
|
||||
|
||||
ids_list = [int(x) for x in ids.tolist()]
|
||||
toks: List[str] = []
|
||||
|
||||
for i in ids_list:
|
||||
tok = self.id_to_token.get(i, "<unk>")
|
||||
if skip_special_tokens and tok in {
|
||||
"<cls>",
|
||||
"<pad>",
|
||||
"<eos>",
|
||||
"<unk>",
|
||||
"<mask>",
|
||||
}:
|
||||
continue
|
||||
toks.append(tok)
|
||||
|
||||
return "".join(toks)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the size of the tokenizer’s vocabulary."""
|
||||
return len(self.vocab)
|
||||
81
esm/main.py
Normal file
81
esm/main.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from esm import ESM2
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ESM-2 MLX Inference")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
default="checkpoints/mlx-esm2_t33_650M_UR50D",
|
||||
help="Path to MLX model checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sequence",
|
||||
default="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
|
||||
help="Protein sequence to test (default: human insulin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mask-position",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Position to mask (default: middle of sequence)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load pretrained ESM-2 model and tokenizer
|
||||
tokenizer, model = ESM2.from_pretrained(args.model_path)
|
||||
|
||||
# Determine sequence and position to mask
|
||||
sequence = args.sequence.upper()
|
||||
mask_pos = (
|
||||
args.mask_position if args.mask_position is not None else len(sequence) // 2
|
||||
)
|
||||
if mask_pos >= len(sequence):
|
||||
mask_pos = len(sequence) - 1
|
||||
original_aa = sequence[mask_pos] # The original amino acid at masked position
|
||||
|
||||
# Display input info
|
||||
print(f"Original sequence: {sequence}")
|
||||
print(f"Masked sequence: {sequence[:mask_pos]}<mask>{sequence[mask_pos+1:]}")
|
||||
print(f"Predicting position {mask_pos}: {original_aa}\n")
|
||||
|
||||
# Tokenize sequence before and after the mask
|
||||
before = tokenizer.encode(sequence[:mask_pos], add_special_tokens=False)
|
||||
after = tokenizer.encode(sequence[mask_pos + 1 :], add_special_tokens=False)
|
||||
|
||||
# Build token sequence with <cls>, <mask>, and <eos>
|
||||
tokens = mx.array(
|
||||
[
|
||||
[tokenizer.cls_id]
|
||||
+ before.tolist()
|
||||
+ [tokenizer.mask_id]
|
||||
+ after.tolist()
|
||||
+ [tokenizer.eos_id]
|
||||
]
|
||||
)
|
||||
mask_token_pos = 1 + len(before) # Position of <mask> token
|
||||
|
||||
# Run model to get logits for each token position
|
||||
logits = model(tokens)["logits"]
|
||||
probs = mx.softmax(
|
||||
logits[0, mask_token_pos, :]
|
||||
) # Softmax over vocabulary at mask position
|
||||
|
||||
# Get top-5 most likely tokens
|
||||
top_indices = mx.argsort(probs)[-5:][::-1]
|
||||
|
||||
# Print predictions
|
||||
print("Top predictions:")
|
||||
for i, idx in enumerate(top_indices):
|
||||
token = tokenizer.vocab[int(idx)]
|
||||
if token in tokenizer.vocab:
|
||||
prob = float(probs[idx])
|
||||
marker = "✓" if token == original_aa else " "
|
||||
print(f"{marker} {i+1}. {token}: {prob:.3f} ({prob*100:.1f}%)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
602
esm/notebooks/contact_prediction.ipynb
Normal file
602
esm/notebooks/contact_prediction.ipynb
Normal file
@@ -0,0 +1,602 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3fbacbe4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Predicting Protein Contacts with ESM-2\n",
|
||||
"\n",
|
||||
"Understanding how amino acids interact within a folded protein is essential for grasping protein function and stability. Contact prediction, the task of identifying which residues are close together in three-dimensional space, is a key step in the sequence to structure process. ESM-2’s learned attention patterns capture evolutionary signals that encode structural information, which allows the model to predict residue contacts directly from sequence data.\n",
|
||||
"\n",
|
||||
"In this notebook, we'll explore ESM-2's ability to predict protein contacts across three diverse proteins from different organisms:\n",
|
||||
"\n",
|
||||
"**Bacterial Transport:**\n",
|
||||
"- **1a3a (PTS Mannitol Component)**: A phosphoenolpyruvate-dependent sugar phosphotransferase system component from *E. coli*, essential for carbohydrate metabolism\n",
|
||||
"\n",
|
||||
"**Stress Response:**\n",
|
||||
"- **5ahw (Universal Stress Protein)**: A conserved stress response protein from *Mycolicibacterium smegmatis* that helps cells survive harsh conditions\n",
|
||||
"\n",
|
||||
"**Human Metabolism:**\n",
|
||||
"- **1xcr (Ester Hydrolase)**: A human enzyme (C11orf54) involved in lipid metabolism and cellular signaling pathways\n",
|
||||
"\n",
|
||||
"We will evaluate how effectively ESM-2 captures the structural relationships present in these sequences, measuring precision across different sequence separation ranges to assess both local and long-range contact prediction performance. This notebook is a modified version of a [notebook by the same name](https://github.com/facebookresearch/esm/blob/main/examples/contact_prediction.ipynb) from the [offical ESM repsitory](https://github.com/facebookresearch/esm)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "08352b12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"\n",
|
||||
"Here we import all neccessary libraries."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c1047c94",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31mRunning cells with '.venv (Python 3.11.13)' requires the ipykernel package.\n",
|
||||
"\u001b[1;31mInstall 'ipykernel' into the Python environment. \n",
|
||||
"\u001b[1;31mCommand: '/Users/vincent/Developer/mlx-examples/.venv/bin/python -m pip install ipykernel -U --force-reinstall'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import List, Tuple, Optional, Dict\n",
|
||||
"import string\n",
|
||||
"\n",
|
||||
"import mlx.core as mx\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from scipy.spatial.distance import squareform, pdist\n",
|
||||
"import biotite.structure as bs\n",
|
||||
"from biotite.database import rcsb\n",
|
||||
"from biotite.structure.io.pdbx import CIFFile, get_structure\n",
|
||||
"from Bio import SeqIO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5f0af076",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Download multiple sequence alignment (MSA) files for our three test proteins from the ESM repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3264b66d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!mkdir -p data\n",
|
||||
"!curl -o data/1a3a_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1a3a_1_A.a3m\n",
|
||||
"!curl -o data/5ahw_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/5ahw_1_A.a3m\n",
|
||||
"!curl -o data/1xcr_1_A.a3m https://raw.githubusercontent.com/facebookresearch/esm/main/examples/data/1xcr_1_A.a3m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cbf1d0cb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Loading the model\n",
|
||||
"\n",
|
||||
"Load the ESM-2 model. Here we will use the 650M parameter version. Change the path below to point to your converted checkpoint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4406e8a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append(\"..\")\n",
|
||||
"\n",
|
||||
"from esm import ESM2\n",
|
||||
"\n",
|
||||
"esm_checkpoint = \"../checkpoints/mlx-esm2_t33_650M_UR50D\"\n",
|
||||
"tokenizer, model = ESM2.from_pretrained(esm_checkpoint)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "77596456",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eb5f07ed",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Parsing alignments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e754abd7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This function parses multiple sequence alignment files and clean up insertion artifacts. MSA files often contain lowercase letters and special characters (`.`, `*`) to indicate insertions relative to the reference sequence - these need to be removed to get the core aligned sequences."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "43717bea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"deletekeys = dict.fromkeys(string.ascii_lowercase)\n",
|
||||
"deletekeys[\".\"] = None\n",
|
||||
"deletekeys[\"*\"] = None\n",
|
||||
"translation = str.maketrans(deletekeys)\n",
|
||||
"\n",
|
||||
"def read_sequence(filename: str) -> Tuple[str, str]:\n",
|
||||
" \"\"\" Reads the first (reference) sequences from a fasta or MSA file.\"\"\"\n",
|
||||
" record = next(SeqIO.parse(filename, \"fasta\"))\n",
|
||||
" return record.description, str(record.seq)\n",
|
||||
"\n",
|
||||
"def remove_insertions(sequence: str) -> str:\n",
|
||||
" \"\"\" Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. \"\"\"\n",
|
||||
" return sequence.translate(translation)\n",
|
||||
"\n",
|
||||
"def read_msa(filename: str) -> List[Tuple[str, str]]:\n",
|
||||
" \"\"\" Reads the sequences from an MSA file, automatically removes insertions.\"\"\"\n",
|
||||
" return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, \"fasta\")]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "628d7de1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Converting structures to contacts\n",
|
||||
"\n",
|
||||
"There are many ways to define a protein contact. Here we're using the definition of 8 angstroms between carbon beta atoms. Note that the position of the carbon beta is imputed from the position of the N, CA, and C atoms for each residue."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "21e0b44b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def extend(a, b, c, L, A, D):\n",
|
||||
" \"\"\"\n",
|
||||
" input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral\n",
|
||||
" output: 4th coord\n",
|
||||
" \"\"\"\n",
|
||||
" def normalize(x):\n",
|
||||
" return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)\n",
|
||||
"\n",
|
||||
" bc = normalize(b - c)\n",
|
||||
" n = normalize(np.cross(b - a, bc))\n",
|
||||
" m = [bc, np.cross(n, bc), n]\n",
|
||||
" d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]\n",
|
||||
" return c + sum([m * d for m, d in zip(m, d)])\n",
|
||||
"\n",
|
||||
"def contacts_from_pdb(\n",
|
||||
" structure: bs.AtomArray,\n",
|
||||
" distance_threshold: float = 8.0,\n",
|
||||
" chain: Optional[str] = None,\n",
|
||||
") -> np.ndarray:\n",
|
||||
" \"\"\"Extract contacts from PDB structure.\"\"\"\n",
|
||||
" mask = ~structure.hetero\n",
|
||||
" if chain is not None:\n",
|
||||
" mask &= structure.chain_id == chain\n",
|
||||
"\n",
|
||||
" N = structure.coord[mask & (structure.atom_name == \"N\")]\n",
|
||||
" CA = structure.coord[mask & (structure.atom_name == \"CA\")]\n",
|
||||
" C = structure.coord[mask & (structure.atom_name == \"C\")]\n",
|
||||
"\n",
|
||||
" Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)\n",
|
||||
" dist = squareform(pdist(Cbeta))\n",
|
||||
" \n",
|
||||
" contacts = dist < distance_threshold\n",
|
||||
" contacts = contacts.astype(np.int64)\n",
|
||||
" contacts[np.isnan(dist)] = -1\n",
|
||||
" return contacts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5473f306",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Computing contact precisions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e361a9f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calculate precision metrics to evaluate contact prediction quality. The `compute_precisions` function ranks predicted contacts by confidence and measures how many of the top predictions are true contacts, while `evaluate_prediction` breaks this down by sequence separation ranges (local, short, medium, long-range) since predicting distant contacts is typically much harder than nearby ones."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62c37bbd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_precisions(\n",
|
||||
" predictions: mx.array,\n",
|
||||
" targets: mx.array,\n",
|
||||
" minsep: int = 6,\n",
|
||||
" maxsep: Optional[int] = None,\n",
|
||||
" override_length: Optional[int] = None,\n",
|
||||
") -> Dict[str, mx.array]:\n",
|
||||
" \"\"\"Compute precision metrics for contact prediction.\"\"\"\n",
|
||||
" batch_size, seqlen, _ = predictions.shape\n",
|
||||
" \n",
|
||||
" if maxsep is not None:\n",
|
||||
" sep_mask_2d = mx.abs(mx.arange(seqlen)[None, :] - mx.arange(seqlen)[:, None]) <= maxsep\n",
|
||||
" targets = targets * sep_mask_2d[None, :]\n",
|
||||
" \n",
|
||||
" targets = targets.astype(mx.float32)\n",
|
||||
" src_lengths = (targets >= 0).sum(axis=-1).sum(axis=-1).astype(mx.float32)\n",
|
||||
" \n",
|
||||
" x_ind, y_ind = [], []\n",
|
||||
" for i in range(seqlen):\n",
|
||||
" for j in range(i + minsep, seqlen):\n",
|
||||
" x_ind.append(i)\n",
|
||||
" y_ind.append(j)\n",
|
||||
" \n",
|
||||
" x_ind = mx.array(x_ind)\n",
|
||||
" y_ind = mx.array(y_ind)\n",
|
||||
" \n",
|
||||
" predictions_upper = predictions[:, x_ind, y_ind]\n",
|
||||
" targets_upper = targets[:, x_ind, y_ind]\n",
|
||||
"\n",
|
||||
" topk = seqlen if override_length is None else max(seqlen, override_length)\n",
|
||||
" indices = mx.argsort(predictions_upper, axis=-1)[:, ::-1][:, :topk]\n",
|
||||
" \n",
|
||||
" batch_indices = mx.arange(batch_size)[:, None]\n",
|
||||
" topk_targets = targets_upper[batch_indices, indices]\n",
|
||||
" \n",
|
||||
" if topk_targets.shape[1] < topk:\n",
|
||||
" pad_shape = (topk_targets.shape[0], topk - topk_targets.shape[1])\n",
|
||||
" padding = mx.zeros(pad_shape)\n",
|
||||
" topk_targets = mx.concatenate([topk_targets, padding], 1)\n",
|
||||
"\n",
|
||||
" cumulative_dist = mx.cumsum(topk_targets, -1)\n",
|
||||
"\n",
|
||||
" gather_lengths = src_lengths[:, None]\n",
|
||||
" if override_length is not None:\n",
|
||||
" gather_lengths = override_length * mx.ones_like(gather_lengths)\n",
|
||||
"\n",
|
||||
" precision_fractions = mx.arange(0.1, 1.1, 0.1)\n",
|
||||
" gather_indices = (precision_fractions[None, :] * gather_lengths) - 1\n",
|
||||
" gather_indices = mx.clip(gather_indices, 0, cumulative_dist.shape[1] - 1)\n",
|
||||
" gather_indices = gather_indices.astype(mx.int32)\n",
|
||||
"\n",
|
||||
" binned_cumulative_dist = cumulative_dist[batch_indices, gather_indices]\n",
|
||||
" binned_precisions = binned_cumulative_dist / (gather_indices + 1)\n",
|
||||
"\n",
|
||||
" pl5 = binned_precisions[:, 1]\n",
|
||||
" pl2 = binned_precisions[:, 4]\n",
|
||||
" pl = binned_precisions[:, 9]\n",
|
||||
" auc = binned_precisions.mean(-1)\n",
|
||||
"\n",
|
||||
" return {\"AUC\": auc, \"P@L\": pl, \"P@L2\": pl2, \"P@L5\": pl5}\n",
|
||||
"\n",
|
||||
"def evaluate_prediction(\n",
|
||||
" predictions: mx.array,\n",
|
||||
" targets: mx.array,\n",
|
||||
") -> Dict[str, float]:\n",
|
||||
" \"\"\"Evaluate contact predictions across different sequence separation ranges.\"\"\"\n",
|
||||
" contact_ranges = [\n",
|
||||
" (\"local\", 3, 6),\n",
|
||||
" (\"short\", 6, 12),\n",
|
||||
" (\"medium\", 12, 24),\n",
|
||||
" (\"long\", 24, None),\n",
|
||||
" ]\n",
|
||||
" metrics = {}\n",
|
||||
" \n",
|
||||
" for name, minsep, maxsep in contact_ranges:\n",
|
||||
" rangemetrics = compute_precisions(\n",
|
||||
" predictions,\n",
|
||||
" targets,\n",
|
||||
" minsep=minsep,\n",
|
||||
" maxsep=maxsep,\n",
|
||||
" )\n",
|
||||
" for key, val in rangemetrics.items():\n",
|
||||
" metrics[f\"{name}_{key}\"] = float(val[0])\n",
|
||||
" return metrics"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5873e052",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Predicting contacts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d5778a9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This function wraps the tokenization and model inference steps, converting a raw amino acid sequence into token IDs and passing them through ESM-2's contact prediction head to produce a contact probability matrix."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dddf31a7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict_contacts(sequence: str, model, tokenizer) -> mx.array:\n",
|
||||
" \"\"\" Predict contacts for a given sequence \"\"\"\n",
|
||||
" tokens = tokenizer.encode(sequence)\n",
|
||||
" contacts = model.predict_contacts(tokens)\n",
|
||||
" return contacts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62562401",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Plotting results\n",
|
||||
"\n",
|
||||
"This function visualizes contacts as a symmetric matrix where both axes index residue positions. The lower triangle shows the model’s confidence as a blue heatmap, with darker cells indicating higher confidence. The upper triangle overlays evaluation markers: blue dots are correctly predicted contacts (true positives), red dots are predicted but not real (false positives), and grey dots are real contacts the model missed (false negatives)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "03e03791",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_contacts_and_predictions(\n",
|
||||
" predictions: mx.array,\n",
|
||||
" contacts: np.ndarray,\n",
|
||||
" ax,\n",
|
||||
" title: str,\n",
|
||||
" cmap: str = \"Blues\",\n",
|
||||
" ms: float = 1,\n",
|
||||
"):\n",
|
||||
" \"\"\"Plot contact predictions and true contacts.\"\"\"\n",
|
||||
" if isinstance(predictions, mx.array):\n",
|
||||
" predictions = np.array(predictions)\n",
|
||||
" \n",
|
||||
" seqlen = contacts.shape[0]\n",
|
||||
" relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))\n",
|
||||
" bottom_mask = relative_distance < 0\n",
|
||||
" masked_image = np.ma.masked_where(bottom_mask, predictions)\n",
|
||||
" invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6\n",
|
||||
" predictions_copy = predictions.copy()\n",
|
||||
" predictions_copy[invalid_mask] = float(\"-inf\")\n",
|
||||
"\n",
|
||||
" topl_val = np.sort(predictions_copy.reshape(-1))[-seqlen]\n",
|
||||
" pred_contacts = predictions_copy >= topl_val\n",
|
||||
" true_positives = contacts & pred_contacts & ~bottom_mask\n",
|
||||
" false_positives = ~contacts & pred_contacts & ~bottom_mask\n",
|
||||
" other_contacts = contacts & ~pred_contacts & ~bottom_mask\n",
|
||||
"\n",
|
||||
" ax.imshow(masked_image, cmap=cmap)\n",
|
||||
" ax.plot(*np.where(other_contacts), \"o\", c=\"grey\", ms=ms)\n",
|
||||
" ax.plot(*np.where(false_positives), \"o\", c=\"r\", ms=ms)\n",
|
||||
" ax.plot(*np.where(true_positives), \"o\", c=\"b\", ms=ms)\n",
|
||||
" ax.set_title(title)\n",
|
||||
" ax.axis(\"square\")\n",
|
||||
" ax.set_xlim([0, seqlen])\n",
|
||||
" ax.set_ylim([0, seqlen])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9364c984",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Predict and visualize\n",
|
||||
"Here we'll use ESM-2 contact prediction on our three test proteins and evaluate the results. We'll compute precision metrics across different sequence separation ranges and create contact maps that visualize both the model's predictions and how well they match the true protein structures."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9fa9e59e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Read Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7da50dc2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load experimental protein structures from the Protein Data Bank and extract true contact maps for evaluation, while also parsing the reference sequences from our MSA files that will serve as input to ESM-2."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2d276137",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PDB_IDS = [\"1a3a\", \"5ahw\", \"1xcr\"]\n",
|
||||
"\n",
|
||||
"structures = {\n",
|
||||
" name.lower(): get_structure(CIFFile.read(rcsb.fetch(name, \"cif\")))[0]\n",
|
||||
" for name in PDB_IDS\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"contacts = {\n",
|
||||
" name: contacts_from_pdb(structure, chain=\"A\") \n",
|
||||
" for name, structure in structures.items()\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"msas = {\n",
|
||||
" name: read_msa(f\"data/{name.lower()}_1_A.a3m\")\n",
|
||||
" for name in PDB_IDS\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"sequences = {\n",
|
||||
" name: msa[0] for name, msa in msas.items()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4ce64f18",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### ESM-2 predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f2da88f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Evaluate predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0adb0a11",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This loop generates contact predictions for each protein using ESM-2, compares them against the experimentally determined structures, and computes precision metrics across different sequence separation ranges to evaluate model performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "941b4afa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictions = {}\n",
|
||||
"results = []\n",
|
||||
"\n",
|
||||
"for pdb_id in sequences:\n",
|
||||
" _, sequence = sequences[pdb_id]\n",
|
||||
" prediction = predict_contacts(sequence, model, tokenizer)\n",
|
||||
" predictions[pdb_id] = prediction[0]\n",
|
||||
" \n",
|
||||
" true_contacts = mx.array(contacts[pdb_id])\n",
|
||||
" \n",
|
||||
" min_len = min(prediction.shape[1], true_contacts.shape[0])\n",
|
||||
" pred_trimmed = prediction[:, :min_len, :min_len]\n",
|
||||
" true_trimmed = true_contacts[:min_len, :min_len]\n",
|
||||
" true_trimmed = mx.expand_dims(true_trimmed, axis=0)\n",
|
||||
" \n",
|
||||
" metrics = evaluate_prediction(pred_trimmed, true_trimmed)\n",
|
||||
" result = {\"id\": pdb_id, \"model\": \"ESM-2 (Unsupervised)\"}\n",
|
||||
" result.update(metrics)\n",
|
||||
" results.append(result)\n",
|
||||
"\n",
|
||||
"results_df = pd.DataFrame(results)\n",
|
||||
"display(results_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c5c7418a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The results demonstrate that ESM-2 excels at predicting long-range contacts, with precision scores ranging from 40.9% to 86.4% for residues more than 24 positions apart. Performance is consistently higher for distant contacts compared to local ones. For example, the universal stress protein (5ahw) achieves 86.4% precision for long-range contacts but only 2.4% for local contacts between 3 and 6 residues apart. This trend is observed across all three proteins, with medium-range contacts (12–24 residues apart) and short-range contacts (6–12 residues apart) showing intermediate accuracy. These results suggest that ESM-2 has learned to identify evolutionarily conserved structural motifs that connect distant regions of the sequence, which are often critical for protein fold stability and function."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "487cff51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"##### Plot contacts and predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "10291191",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This analysis generates contact map visualizations for all three proteins, presenting ESM-2’s predictions as heatmaps and overlaying the true experimental contacts as colored dots."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "628efc10",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"proteins = [r['id'] for r in results]\n",
|
||||
"fig, axes = plt.subplots(figsize=(6 * len(proteins), 6), ncols=len(proteins))\n",
|
||||
"if len(proteins) == 1:\n",
|
||||
" axes = [axes]\n",
|
||||
"\n",
|
||||
"for ax, pdb_id in zip(axes, proteins):\n",
|
||||
" prediction = predictions[pdb_id]\n",
|
||||
" target = contacts[pdb_id]\n",
|
||||
" \n",
|
||||
" result = next(r for r in results if r['id'] == pdb_id)\n",
|
||||
" long_pl = result['long_P@L']\n",
|
||||
" \n",
|
||||
" plot_contacts_and_predictions(\n",
|
||||
" prediction, target, ax=ax, \n",
|
||||
" title=f\"{pdb_id}: Long Range P@L: {100 * long_pl:.1f}%\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "99e1edaf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The contact maps highlight ESM-2’s strong ability to detect long-range structural relationships. In each panel, the lower triangle shows model predictions, where darker blue regions indicate high-confidence contacts, and the upper triangle shows the corresponding experimental data. Correct predictions appear as blue dots, forming distinct off-diagonal patterns in 5ahw and 1a3a that capture key global fold interactions. Red dots mark false positives, which are relatively rare, while grey dots represent missed contacts. These missed contacts are notably more frequent in 1xcr, consistent with its lower long-range precision. The dense clusters of blue true positives in 5ahw, compared to the sparser, fragmented patterns in 1xcr, clearly illustrate the variation in predictive performance across proteins."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
334
esm/notebooks/embeddings.ipynb
Normal file
334
esm/notebooks/embeddings.ipynb
Normal file
File diff suppressed because one or more lines are too long
662
esm/notebooks/mutation_effect_prediction.ipynb
Normal file
662
esm/notebooks/mutation_effect_prediction.ipynb
Normal file
File diff suppressed because one or more lines are too long
12
esm/requirements.txt
Normal file
12
esm/requirements.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
mlx
|
||||
torch
|
||||
transformers
|
||||
numpy
|
||||
pandas
|
||||
seaborn
|
||||
biopython
|
||||
biotite
|
||||
scipy
|
||||
tqdm
|
||||
scikit-learn
|
||||
matplotlib
|
||||
121
esm/test.py
Normal file
121
esm/test.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, EsmConfig, EsmForMaskedLM
|
||||
|
||||
from esm import ESM2
|
||||
|
||||
# Paths for MLX and Hugging Face versions of ESM-2
|
||||
MLX_PATH = "checkpoints/mlx-esm2_t12_35M_UR50D"
|
||||
HF_PATH = "facebook/esm2_t12_35M_UR50D"
|
||||
|
||||
|
||||
def load_mlx_model():
|
||||
"""Load MLX ESM-2 model and tokenizer."""
|
||||
tokenizer, model = ESM2.from_pretrained(MLX_PATH)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_hf_model():
|
||||
"""Load Hugging Face ESM-2 model and tokenizer with hidden states + attentions."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_PATH)
|
||||
config = EsmConfig.from_pretrained(
|
||||
HF_PATH, output_hidden_states=True, output_attentions=True
|
||||
)
|
||||
model = EsmForMaskedLM.from_pretrained(HF_PATH, config=config)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
class TestESM2(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Load both MLX and HF models/tokenizers once for all tests
|
||||
cls.mlx_tokenizer, cls.mlx_model = load_mlx_model()
|
||||
cls.hf_tokenizer, cls.hf_model = load_hf_model()
|
||||
|
||||
def test_tokenizer(self):
|
||||
"""Verify MLX tokenizer matches Hugging Face tokenizer behavior."""
|
||||
self.assertEqual(len(self.mlx_tokenizer), len(self.hf_tokenizer))
|
||||
|
||||
sequences = [
|
||||
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
|
||||
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
|
||||
]
|
||||
|
||||
# Compare batched tokenization (padded sequences)
|
||||
mlx_batch = self.mlx_tokenizer.batch_encode(sequences)
|
||||
hf_batch = (
|
||||
self.hf_tokenizer(sequences, return_tensors="pt", padding=True)["input_ids"]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
self.assertEqual(tuple(mlx_batch.shape), tuple(hf_batch.shape))
|
||||
self.assertTrue(
|
||||
np.array_equal(np.array(mlx_batch.tolist(), dtype=hf_batch.dtype), hf_batch)
|
||||
)
|
||||
|
||||
# Compare single-sequence encode/decode
|
||||
for sequence in sequences:
|
||||
mlx_tokens = self.mlx_tokenizer.encode(sequence)
|
||||
hf_tokens = (
|
||||
self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.tolist()[0]
|
||||
)
|
||||
self.assertTrue(np.array_equal(mlx_tokens, hf_tokens))
|
||||
self.assertEqual(
|
||||
self.mlx_tokenizer.decode(mlx_tokens),
|
||||
self.hf_tokenizer.decode(hf_tokens).replace(" ", ""),
|
||||
)
|
||||
self.assertEqual(
|
||||
self.mlx_tokenizer.decode(mlx_tokens, skip_special_tokens=True),
|
||||
self.hf_tokenizer.decode(hf_tokens, skip_special_tokens=True).replace(
|
||||
" ", ""
|
||||
),
|
||||
)
|
||||
|
||||
def test_model(self):
|
||||
"""Verify MLX and HF model outputs match (logits, hidden states, attentions)."""
|
||||
sequences = [
|
||||
"MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
|
||||
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN",
|
||||
]
|
||||
for sequence in sequences:
|
||||
# Tokenize
|
||||
mlx_tokens = self.mlx_tokenizer.encode(sequence, return_batch_dim=True)
|
||||
hf_tokens = self.hf_tokenizer(sequence, return_tensors="pt")["input_ids"]
|
||||
|
||||
# Forward pass
|
||||
mlx_outputs = self.mlx_model(
|
||||
mlx_tokens,
|
||||
repr_layers=[self.mlx_model.num_layers],
|
||||
need_head_weights=True,
|
||||
)
|
||||
hf_outputs = self.hf_model(input_ids=hf_tokens)
|
||||
|
||||
# Compare logits
|
||||
mlx_logits = np.array(mlx_outputs["logits"])
|
||||
hf_logits = hf_outputs["logits"].detach().cpu().numpy()
|
||||
self.assertTrue(np.allclose(mlx_logits, hf_logits, rtol=1e-4, atol=1e-4))
|
||||
|
||||
# Compare final-layer hidden states
|
||||
final_layer = self.mlx_model.num_layers
|
||||
mlx_hidden_states = np.array(mlx_outputs["representations"][final_layer])
|
||||
hf_hidden_states = hf_outputs["hidden_states"][-1].detach().cpu().numpy()
|
||||
self.assertTrue(
|
||||
np.allclose(mlx_hidden_states, hf_hidden_states, rtol=1e-4, atol=1e-4)
|
||||
)
|
||||
|
||||
# Compare attentions for final layer
|
||||
mlx_attentions = np.array(
|
||||
mlx_outputs["attentions"][:, final_layer - 1, :, :, :]
|
||||
)
|
||||
hf_attentions = hf_outputs["attentions"][-1].detach().cpu().numpy()
|
||||
self.assertTrue(
|
||||
np.allclose(mlx_attentions, hf_attentions, rtol=1e-4, atol=1e-4)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -167,8 +167,9 @@ python dreambooth.py \
|
||||
path/to/dreambooth/dataset/dog6
|
||||
```
|
||||
|
||||
|
||||
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
|
||||
Or you can directly use the pre-processed Hugging Face dataset
|
||||
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
|
||||
for fine-tuning.
|
||||
|
||||
```shell
|
||||
python dreambooth.py \
|
||||
@@ -210,3 +211,71 @@ speed during generation.
|
||||
|
||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
|
||||
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .
|
||||
|
||||
|
||||
Distributed Computation
|
||||
------------------------
|
||||
|
||||
The FLUX example supports distributed computation during both generation and
|
||||
training. See the [distributed communication
|
||||
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
|
||||
for information on how to set-up MLX for distributed communication. The rest of
|
||||
this section assumes you can launch distributed MLX programs using `mlx.launch
|
||||
--hostfile hostfile.json`.
|
||||
|
||||
### Distributed Finetuning
|
||||
|
||||
Distributed finetuning scales very well with FLUX and all one has to do is
|
||||
adjust the gradient accumulation and training iterations so that the batch
|
||||
size remains the same. For instance, to replicate the following training
|
||||
|
||||
```shell
|
||||
python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 8 \
|
||||
mlx-community/dreambooth-dog6
|
||||
```
|
||||
|
||||
On 4 machines we simply run
|
||||
|
||||
```shell
|
||||
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 150 --iterations 300 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 2 \
|
||||
mlx-community/dreambooth-dog6
|
||||
```
|
||||
|
||||
Note the iterations that changed to 300 from 1200 and the gradient accumulations to 2 from 8.
|
||||
|
||||
### Distributed Inference
|
||||
|
||||
Distributed inference can be divided in two different approaches. The first
|
||||
approach is the data-parallel approach, where each node generates its own
|
||||
images and shares them at the end. The second approach is the model-parallel
|
||||
approach where the model is shared across the nodes and they collaboratively
|
||||
generate the images.
|
||||
|
||||
The `txt2image.py` script will attempt to choose the best approach depending on
|
||||
how many images are being generated across the nodes. The model-parallel
|
||||
approach can be forced by passing the argument `--force-shard`.
|
||||
|
||||
For better performance in the model-parallel approach we suggest that you use a
|
||||
[thunderbolt
|
||||
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
|
||||
|
||||
All you have to do once again is use `mlx.launch` as follows
|
||||
|
||||
```shell
|
||||
mlx.launch --verbose --hostfile hostfile.json -- \
|
||||
python txt2image.py --model schnell \
|
||||
--n-images 8 \
|
||||
--image-size 512x512 \
|
||||
--verbose \
|
||||
'A photo of an astronaut riding a horse on Mars'
|
||||
```
|
||||
|
||||
for model-parallel generation you may want to also pass `--env
|
||||
MLX_METAL_FAST_SYNCH=1` to `mlx.launch` which is an experimental setting that
|
||||
reduces the CPU/GPU synchronization overhead.
|
||||
|
||||
@@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module):
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.sharding_group = None
|
||||
|
||||
def __call__(
|
||||
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
@@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = _attention(q, k, v, pe)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# Project - cat - average - split
|
||||
txt_attn = self.txt_attn.proj(txt_attn)
|
||||
img_attn = self.img_attn.proj(img_attn)
|
||||
if self.sharding_group is not None:
|
||||
attn = mx.concatenate([txt_attn, img_attn], axis=1)
|
||||
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp(
|
||||
img = img + img_mod1.gate * img_attn
|
||||
img_mlp = self.img_mlp(
|
||||
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
||||
)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp(
|
||||
txt = txt + txt_mod1.gate * txt_attn
|
||||
txt_mlp = self.txt_mlp(
|
||||
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
||||
)
|
||||
|
||||
if self.sharding_group is not None:
|
||||
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
|
||||
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
|
||||
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
|
||||
|
||||
# finalize the img/txt blocks
|
||||
img = img + img_mod2.gate * img_mlp
|
||||
txt = txt + txt_mod2.gate * txt_mlp
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -96,6 +97,47 @@ class Flux(nn.Module):
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None):
|
||||
group = group or mx.distributed.init()
|
||||
N = group.size()
|
||||
if N == 1:
|
||||
return
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.num_heads //= N
|
||||
block.img_attn.num_heads //= N
|
||||
block.txt_attn.num_heads //= N
|
||||
block.sharding_group = group
|
||||
block.img_attn.qkv = shard_linear(
|
||||
block.img_attn.qkv, "all-to-sharded", segments=3, group=group
|
||||
)
|
||||
block.txt_attn.qkv = shard_linear(
|
||||
block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
|
||||
)
|
||||
shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
|
||||
shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
|
||||
block.img_mlp.layers[0] = shard_linear(
|
||||
block.img_mlp.layers[0], "all-to-sharded", group=group
|
||||
)
|
||||
block.txt_mlp.layers[0] = shard_linear(
|
||||
block.txt_mlp.layers[0], "all-to-sharded", group=group
|
||||
)
|
||||
shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
|
||||
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.num_heads //= N
|
||||
block.hidden_size //= N
|
||||
block.linear1 = shard_linear(
|
||||
block.linear1,
|
||||
"all-to-sharded",
|
||||
segments=[1 / 7, 2 / 7, 3 / 7],
|
||||
group=group,
|
||||
)
|
||||
block.linear2 = shard_linear(
|
||||
block.linear2, "sharded-to-all", segments=[1 / 5], group=group
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
img: mx.array,
|
||||
|
||||
109
flux/generate_interactive.py
Normal file
109
flux/generate_interactive.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
def print_zero(group, *args, **kwargs):
|
||||
if group.rank() == 0:
|
||||
flush = kwargs.pop("flush", True)
|
||||
print(*args, **kwargs, flush=flush)
|
||||
|
||||
|
||||
def quantization_predicate(name, m):
|
||||
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
|
||||
|
||||
|
||||
def to_latent_size(image_size):
|
||||
h, w = image_size
|
||||
h = ((h + 15) // 16) * 16
|
||||
w = ((w + 15) // 16) * 16
|
||||
|
||||
if (h, w) != image_size:
|
||||
print(
|
||||
"Warning: The image dimensions need to be divisible by 16px. "
|
||||
f"Changing size to {h}x{w}."
|
||||
)
|
||||
|
||||
return (h // 8, w // 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate images from a textual prompt using FLUX"
|
||||
)
|
||||
parser.add_argument("--quantize", "-q", action="store_true")
|
||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
args = parser.parse_args()
|
||||
|
||||
flux = FluxPipeline("flux-" + args.model, t5_padding=True)
|
||||
|
||||
if args.quantize:
|
||||
nn.quantize(flux.flow, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
group = mx.distributed.init()
|
||||
if group.size() > 1:
|
||||
flux.flow.shard(group)
|
||||
|
||||
print_zero(group, "Loading models")
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
def print_help():
|
||||
print_zero(group, "The command list:")
|
||||
print_zero(group, "- 'q' to exit")
|
||||
print_zero(group, "- 's HxW' to change the size of the image")
|
||||
print_zero(group, "- 'n S' to change the number of steps")
|
||||
print_zero(group, "- 'h' to print this help")
|
||||
|
||||
print_zero(group, "FLUX interactive session")
|
||||
print_help()
|
||||
seed = 0
|
||||
size = (512, 512)
|
||||
latent_size = to_latent_size(size)
|
||||
steps = 50 if args.model == "dev" else 4
|
||||
while True:
|
||||
prompt = input(">> " if group.rank() == 0 else "")
|
||||
if prompt == "q":
|
||||
break
|
||||
if prompt == "h":
|
||||
print_help()
|
||||
continue
|
||||
if prompt.startswith("s "):
|
||||
size = tuple([int(xi) for xi in prompt[2:].split("x")])
|
||||
print_zero(group, "Setting the size to", size)
|
||||
latent_size = to_latent_size(size)
|
||||
continue
|
||||
if prompt.startswith("n "):
|
||||
steps = int(prompt[2:])
|
||||
print_zero(group, "Setting the steps to", steps)
|
||||
continue
|
||||
|
||||
seed += 1
|
||||
latents = flux.generate_latents(
|
||||
prompt,
|
||||
n_images=1,
|
||||
num_steps=steps,
|
||||
latent_size=latent_size,
|
||||
guidance=4.0,
|
||||
seed=seed,
|
||||
)
|
||||
print_zero(group, "Processing prompt")
|
||||
mx.eval(next(latents))
|
||||
print_zero(group, "Generating latents")
|
||||
for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
|
||||
mx.eval(xt)
|
||||
print_zero(group, "Generating image")
|
||||
xt = flux.decode(xt, latent_size)
|
||||
xt = (xt * 255).astype(mx.uint8)
|
||||
mx.eval(xt)
|
||||
im = Image.fromarray(np.array(xt[0]))
|
||||
im.save(args.output)
|
||||
print_zero(group, "Saved at", args.output, end="\n\n")
|
||||
@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate images from a textual prompt using stable diffusion"
|
||||
description="Generate images from a textual prompt using FLUX"
|
||||
)
|
||||
parser.add_argument("prompt")
|
||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
||||
@@ -62,6 +62,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--adapter")
|
||||
parser.add_argument("--fuse-adapter", action="store_true")
|
||||
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
|
||||
parser.add_argument("--force-shard", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the models
|
||||
@@ -76,6 +77,24 @@ if __name__ == "__main__":
|
||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
# Figure out what kind of distributed generation we should do
|
||||
group = mx.distributed.init()
|
||||
n_images = args.n_images
|
||||
should_gather = False
|
||||
if group.size() > 1:
|
||||
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
|
||||
flux.flow.shard(group)
|
||||
else:
|
||||
n_images //= group.size()
|
||||
should_gather = True
|
||||
|
||||
# If we are sharding we should have the same seed and if we are doing
|
||||
# data parallel generation we should have different seeds
|
||||
if args.seed is None:
|
||||
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
||||
if should_gather:
|
||||
args.seed = args.seed + group.rank()
|
||||
|
||||
if args.preload_models:
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
@@ -83,7 +102,7 @@ if __name__ == "__main__":
|
||||
latent_size = to_latent_size(args.image_size)
|
||||
latents = flux.generate_latents(
|
||||
args.prompt,
|
||||
n_images=args.n_images,
|
||||
n_images=n_images,
|
||||
num_steps=args.steps,
|
||||
latent_size=latent_size,
|
||||
guidance=args.guidance,
|
||||
@@ -93,8 +112,8 @@ if __name__ == "__main__":
|
||||
# First we get and eval the conditioning
|
||||
conditioning = next(latents)
|
||||
mx.eval(conditioning)
|
||||
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
|
||||
mx.metal.reset_peak_memory()
|
||||
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
|
||||
mx.reset_peak_memory()
|
||||
|
||||
# The following is not necessary but it may help in memory constrained
|
||||
# systems by reusing the memory kept by the text encoders.
|
||||
@@ -102,36 +121,42 @@ if __name__ == "__main__":
|
||||
del flux.clip
|
||||
|
||||
# Actual denoising loop
|
||||
for x_t in tqdm(latents, total=args.steps):
|
||||
for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
|
||||
mx.eval(x_t)
|
||||
|
||||
# The following is not necessary but it may help in memory constrained
|
||||
# systems by reusing the memory kept by the flow transformer.
|
||||
del flux.flow
|
||||
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
|
||||
mx.metal.reset_peak_memory()
|
||||
peak_mem_generation = mx.get_peak_memory() / 1024**3
|
||||
mx.reset_peak_memory()
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
|
||||
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
||||
mx.eval(decoded[-1])
|
||||
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
|
||||
peak_mem_decoding = mx.get_peak_memory() / 1024**3
|
||||
peak_mem_overall = max(
|
||||
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
|
||||
)
|
||||
|
||||
# Gather them if each node has different images
|
||||
decoded = mx.concatenate(decoded, axis=0)
|
||||
if should_gather:
|
||||
decoded = mx.distributed.all_gather(decoded)
|
||||
mx.eval(decoded)
|
||||
|
||||
if args.save_raw:
|
||||
*name, suffix = args.output.split(".")
|
||||
name = ".".join(name)
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
x = decoded
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
for i in range(len(x)):
|
||||
im = Image.fromarray(np.array(x[i]))
|
||||
im.save(".".join([name, str(i), suffix]))
|
||||
else:
|
||||
# Arrange them on a grid
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
x = decoded
|
||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
@@ -143,7 +168,7 @@ if __name__ == "__main__":
|
||||
im.save(args.output)
|
||||
|
||||
# Report the peak memory used during generation
|
||||
if args.verbose:
|
||||
if args.verbose and group.rank() == 0:
|
||||
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
|
||||
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
||||
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
# Contributing to MLX LM
|
||||
|
||||
Below are some tips to port LLMs available on Hugging Face to MLX.
|
||||
|
||||
Before starting checkout the [general contribution
|
||||
guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md).
|
||||
|
||||
Next, from this directory, do an editable install:
|
||||
|
||||
```shell
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then check if the model has weights in the
|
||||
[safetensors](https://huggingface.co/docs/safetensors/index) format. If not
|
||||
[follow instructions](https://huggingface.co/spaces/safetensors/convert) to
|
||||
convert it.
|
||||
|
||||
After that, add the model file to the
|
||||
[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models)
|
||||
directory. You can see other examples there. We recommend starting from a model
|
||||
that is similar to the model you are porting.
|
||||
|
||||
Make sure the name of the new model file is the same as the `model_type` in the
|
||||
`config.json`, for example
|
||||
[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17).
|
||||
|
||||
To determine the model layer names, we suggest either:
|
||||
|
||||
- Refer to the Transformers implementation if you are familiar with the
|
||||
codebase.
|
||||
- Load the model weights and check the weight names which will tell you about
|
||||
the model structure.
|
||||
- Look at the names of the weights by inspecting `model.safetensors.index.json`
|
||||
in the Hugging Face repo.
|
||||
|
||||
To add LoRA support edit
|
||||
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
|
||||
|
||||
Finally, add a test for the new modle type to the [model
|
||||
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
|
||||
|
||||
From the `llms/` directory, you can run the tests with:
|
||||
|
||||
```shell
|
||||
python -m unittest discover tests/
|
||||
```
|
||||
@@ -1,2 +0,0 @@
|
||||
include mlx_lm/requirements.txt
|
||||
recursive-include mlx_lm/ *.py
|
||||
284
llms/README.md
284
llms/README.md
@@ -1,282 +1,6 @@
|
||||
## Generate Text with LLMs and MLX
|
||||
# MOVE NOTICE
|
||||
|
||||
The easiest way to get started is to install the `mlx-lm` package:
|
||||
The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).
|
||||
|
||||
**With `pip`**:
|
||||
|
||||
```sh
|
||||
pip install mlx-lm
|
||||
```
|
||||
|
||||
**With `conda`**:
|
||||
|
||||
```sh
|
||||
conda install -c conda-forge mlx-lm
|
||||
```
|
||||
|
||||
The `mlx-lm` package also has:
|
||||
|
||||
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
|
||||
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
||||
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
||||
|
||||
### Quick Start
|
||||
|
||||
To generate text with an LLM use:
|
||||
|
||||
```bash
|
||||
mlx_lm.generate --prompt "Hi!"
|
||||
```
|
||||
|
||||
To chat with an LLM use:
|
||||
|
||||
```bash
|
||||
mlx_lm.chat
|
||||
```
|
||||
|
||||
This will give you a chat REPL that you can use to interact with the LLM. The
|
||||
chat context is preserved during the lifetime of the REPL.
|
||||
|
||||
Commands in `mlx-lm` typically take command line options which let you specify
|
||||
the model, sampling parameters, and more. Use `-h` to see a list of available
|
||||
options for a command, e.g.:
|
||||
|
||||
```bash
|
||||
mlx_lm.generate -h
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
You can use `mlx-lm` as a module:
|
||||
|
||||
```python
|
||||
from mlx_lm import load, generate
|
||||
|
||||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||
|
||||
prompt = "Write a story about Einstein"
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
text = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||
```
|
||||
|
||||
To see a description of all the arguments you can do:
|
||||
|
||||
```
|
||||
>>> help(generate)
|
||||
```
|
||||
|
||||
Check out the [generation
|
||||
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
|
||||
to see how to use the API in more detail.
|
||||
|
||||
The `mlx-lm` package also comes with functionality to quantize and optionally
|
||||
upload models to the Hugging Face Hub.
|
||||
|
||||
You can convert models using the Python API:
|
||||
|
||||
```python
|
||||
from mlx_lm import convert
|
||||
|
||||
repo = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
|
||||
|
||||
convert(repo, quantize=True, upload_repo=upload_repo)
|
||||
```
|
||||
|
||||
This will generate a 4-bit quantized Mistral 7B and upload it to the repo
|
||||
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
|
||||
converted model in the path `mlx_model` by default.
|
||||
|
||||
To see a description of all the arguments you can do:
|
||||
|
||||
```
|
||||
>>> help(convert)
|
||||
```
|
||||
|
||||
#### Streaming
|
||||
|
||||
For streaming generation, use the `stream_generate` function. This yields
|
||||
a generation response object.
|
||||
|
||||
For example,
|
||||
|
||||
```python
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||
model, tokenizer = load(repo)
|
||||
|
||||
prompt = "Write a story about Einstein"
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True
|
||||
)
|
||||
|
||||
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
print(response.text, end="", flush=True)
|
||||
print()
|
||||
```
|
||||
|
||||
### Command Line
|
||||
|
||||
You can also use `mlx-lm` from the command line with:
|
||||
|
||||
```
|
||||
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
|
||||
```
|
||||
|
||||
This will download a Mistral 7B model from the Hugging Face Hub and generate
|
||||
text using the given prompt.
|
||||
|
||||
For a full list of options run:
|
||||
|
||||
```
|
||||
mlx_lm.generate --help
|
||||
```
|
||||
|
||||
To quantize a model from the command line run:
|
||||
|
||||
```
|
||||
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
|
||||
```
|
||||
|
||||
For more options run:
|
||||
|
||||
```
|
||||
mlx_lm.convert --help
|
||||
```
|
||||
|
||||
You can upload new models to Hugging Face by specifying `--upload-repo` to
|
||||
`convert`. For example, to upload a quantized Mistral-7B model to the
|
||||
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
|
||||
|
||||
```
|
||||
mlx_lm.convert \
|
||||
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
-q \
|
||||
--upload-repo mlx-community/my-4bit-mistral
|
||||
```
|
||||
|
||||
Models can also be converted and quantized directly in the
|
||||
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
|
||||
Face Space.
|
||||
|
||||
### Long Prompts and Generations
|
||||
|
||||
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
|
||||
|
||||
- A rotating fixed-size key-value cache.
|
||||
- Prompt caching
|
||||
|
||||
To use the rotating key-value cache pass the argument `--max-kv-size n` where
|
||||
`n` can be any integer. Smaller values like `512` will use very little RAM but
|
||||
result in worse quality. Larger values like `4096` or higher will use more RAM
|
||||
but have better quality.
|
||||
|
||||
Caching prompts can substantially speedup reusing the same long context with
|
||||
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
|
||||
|
||||
```bash
|
||||
cat prompt.txt | mlx_lm.cache_prompt \
|
||||
--model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
--prompt - \
|
||||
--prompt-cache-file mistral_prompt.safetensors
|
||||
```
|
||||
|
||||
Then use the cached prompt with `mlx_lm.generate`:
|
||||
|
||||
```
|
||||
mlx_lm.generate \
|
||||
--prompt-cache-file mistral_prompt.safetensors \
|
||||
--prompt "\nSummarize the above text."
|
||||
```
|
||||
|
||||
The cached prompt is treated as a prefix to the supplied prompt. Also notice
|
||||
when using a cached prompt, the model to use is read from the cache and need
|
||||
not be supplied explicitly.
|
||||
|
||||
Prompt caching can also be used in the Python API in order to to avoid
|
||||
recomputing the prompt. This is useful in multi-turn dialogues or across
|
||||
requests that use the same context. See the
|
||||
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
|
||||
for more usage details.
|
||||
|
||||
### Supported Models
|
||||
|
||||
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
|
||||
run is not supported, file an
|
||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||
submit a pull request.
|
||||
|
||||
Here are a few examples of Hugging Face models that work with this example:
|
||||
|
||||
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
|
||||
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
|
||||
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
|
||||
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
|
||||
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
|
||||
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
|
||||
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
|
||||
|
||||
Most
|
||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
|
||||
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
|
||||
and
|
||||
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
|
||||
style models should work out of the box.
|
||||
|
||||
For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
|
||||
enable the `trust_remote_code` option. You can do this by passing
|
||||
`--trust-remote-code` in the command line. If you don't specify the flag
|
||||
explicitly, you will be prompted to trust remote code in the terminal when
|
||||
running the model.
|
||||
|
||||
For `Qwen` models you must also specify the `eos_token`. You can do this by
|
||||
passing `--eos-token "<|endoftext|>"` in the command
|
||||
line.
|
||||
|
||||
These options can also be set in the Python API. For example:
|
||||
|
||||
```python
|
||||
model, tokenizer = load(
|
||||
"qwen/Qwen-7B",
|
||||
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
|
||||
)
|
||||
```
|
||||
|
||||
### Large Models
|
||||
|
||||
> [!NOTE]
|
||||
This requires macOS 15.0 or higher to work.
|
||||
|
||||
Models which are large relative to the total RAM available on the machine can
|
||||
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
|
||||
occupied by the model and cache. This requires macOS 15 or higher to
|
||||
work.
|
||||
|
||||
If you see the following warning message:
|
||||
|
||||
> [WARNING] Generating with a model that requires ...
|
||||
|
||||
then the model will likely be slow on the given machine. If the model fits in
|
||||
RAM then it can often be sped up by increasing the system wired memory limit.
|
||||
To increase the limit, set the following `sysctl`:
|
||||
|
||||
```bash
|
||||
sudo sysctl iogpu.wired_limit_mb=N
|
||||
```
|
||||
|
||||
The value `N` should be larger than the size of the model in megabytes but
|
||||
smaller than the memory size of the machine.
|
||||
The package has been removed from the MLX Examples repo. Send new contributions
|
||||
and issues to the MLX LM repo.
|
||||
|
||||
@@ -40,7 +40,7 @@ def generate(
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
prompt_tps = len(prompt) / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
|
||||
@@ -19,10 +19,10 @@ class ModelArgs:
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
context_length: int
|
||||
num_key_value_heads: int = None
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
model_type: Optional[str] = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -54,7 +54,7 @@ class Attention(nn.Module):
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
|
||||
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
|
||||
@@ -66,7 +66,7 @@ class Attention(nn.Module):
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
1 / float(args.rope_scaling["factor"])
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
@@ -254,7 +254,7 @@ def translate_weight_names(name):
|
||||
return name
|
||||
|
||||
|
||||
def load(gguf_file: str, repo: str = None):
|
||||
def load(gguf_file: str, repo: Optional[str] = None):
|
||||
# If the gguf_file exists, try to load model from it.
|
||||
# Otherwise try to download and cache from the HF repo
|
||||
if not Path(gguf_file).exists():
|
||||
|
||||
@@ -7,6 +7,7 @@ import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
shard: Dict[str, mx.array] = {}
|
||||
shard_size = 0
|
||||
for k, v in weights.items():
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
|
||||
@@ -23,7 +23,7 @@ class ModelArgs:
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
moe: dict = None
|
||||
moe: dict
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.num_experts = args.moe["num_experts"]
|
||||
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
||||
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
|
||||
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
# Fine-Tuning with LoRA or QLoRA
|
||||
|
||||
You can use use the `mlx-lm` package to fine-tune an LLM with low rank
|
||||
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
|
||||
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
|
||||
|
||||
- Mistral
|
||||
- Llama
|
||||
- Phi2
|
||||
- Mixtral
|
||||
- Qwen2
|
||||
- Gemma
|
||||
- OLMo
|
||||
- MiniCPM
|
||||
- InternLM2
|
||||
|
||||
## Contents
|
||||
|
||||
- [Run](#Run)
|
||||
- [Fine-tune](#Fine-tune)
|
||||
- [Evaluate](#Evaluate)
|
||||
- [Generate](#Generate)
|
||||
- [Fuse](#Fuse)
|
||||
- [Data](#Data)
|
||||
- [Memory Issues](#Memory-Issues)
|
||||
|
||||
## Run
|
||||
|
||||
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora --help
|
||||
```
|
||||
|
||||
Note, in the following the `--model` argument can be any compatible Hugging
|
||||
Face repo or a local path to a converted model.
|
||||
|
||||
You can also specify a YAML config with `-c`/`--config`. For more on the format see the
|
||||
[example YAML](examples/lora_config.yaml). For example:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
If command-line flags are also used, they will override the corresponding
|
||||
values in the config.
|
||||
|
||||
### Fine-tune
|
||||
|
||||
To fine-tune a model use:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora \
|
||||
--model <path_to_model> \
|
||||
--train \
|
||||
--data <path_to_data> \
|
||||
--iters 600
|
||||
```
|
||||
|
||||
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
|
||||
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
|
||||
|
||||
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
|
||||
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
|
||||
details on the data format see the section on [Data](#Data).
|
||||
|
||||
For example, to fine-tune a Mistral 7B you can use `--model
|
||||
mistralai/Mistral-7B-v0.1`.
|
||||
|
||||
If `--model` points to a quantized model, then the training will use QLoRA,
|
||||
otherwise it will use regular LoRA.
|
||||
|
||||
By default, the adapter config and learned weights are saved in `adapters/`.
|
||||
You can specify the output location with `--adapter-path`.
|
||||
|
||||
You can resume fine-tuning with an existing adapter with
|
||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||
|
||||
### Evaluate
|
||||
|
||||
To compute test set perplexity use:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora \
|
||||
--model <path_to_model> \
|
||||
--adapter-path <path_to_adapters> \
|
||||
--data <path_to_data> \
|
||||
--test
|
||||
```
|
||||
|
||||
### Generate
|
||||
|
||||
For generation use `mlx_lm.generate`:
|
||||
|
||||
```shell
|
||||
mlx_lm.generate \
|
||||
--model <path_to_model> \
|
||||
--adapter-path <path_to_adapters> \
|
||||
--prompt "<your_model_prompt>"
|
||||
```
|
||||
|
||||
## Fuse
|
||||
|
||||
You can generate a model fused with the low-rank adapters using the
|
||||
`mlx_lm.fuse` command. This command also allows you to optionally:
|
||||
|
||||
- Upload the fused model to the Hugging Face Hub.
|
||||
- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
|
||||
Mixtral, and Llama style models in fp16 precision.
|
||||
|
||||
To see supported options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse --help
|
||||
```
|
||||
|
||||
To generate the fused model run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse --model <path_to_model>
|
||||
```
|
||||
|
||||
This will by default load the adapters from `adapters/`, and save the fused
|
||||
model in the path `fused_model/`. All of these are configurable.
|
||||
|
||||
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
|
||||
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
|
||||
useful for the sake of attribution and model versioning.
|
||||
|
||||
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--upload-repo mlx-community/my-lora-mistral-7b \
|
||||
--hf-path mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
To export a fused model to GGUF, run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--export-gguf
|
||||
```
|
||||
|
||||
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
|
||||
can specify the file name with `--gguf-path`.
|
||||
|
||||
## Data
|
||||
|
||||
The LoRA command expects you to provide a dataset with `--data`. The MLX
|
||||
Examples GitHub repo has an [example of the WikiSQL
|
||||
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
|
||||
correct format.
|
||||
|
||||
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
|
||||
Face.
|
||||
|
||||
### Local Datasets
|
||||
|
||||
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||
loader expects a `test.jsonl` in the data directory.
|
||||
|
||||
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
|
||||
data formats. Here are examples of these formats:
|
||||
|
||||
`chat`:
|
||||
|
||||
```jsonl
|
||||
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
|
||||
```
|
||||
|
||||
`tools`:
|
||||
|
||||
```jsonl
|
||||
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>View the expanded single data tool format</summary>
|
||||
|
||||
```jsonl
|
||||
{
|
||||
"messages": [
|
||||
{ "role": "user", "content": "What is the weather in San Francisco?" },
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_id",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and country, eg. San Francisco, USA"
|
||||
},
|
||||
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
|
||||
},
|
||||
"required": ["location", "format"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
The format for the `arguments` field in a function varies for different models.
|
||||
Common formats include JSON strings and dictionaries. The example provided
|
||||
follows the format used by
|
||||
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
|
||||
and [Mistral
|
||||
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
|
||||
A dictionary format is used in Hugging Face's [chat
|
||||
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
|
||||
Refer to the documentation for the model you are fine-tuning for more details.
|
||||
|
||||
</details>
|
||||
|
||||
`completions`:
|
||||
|
||||
```jsonl
|
||||
{"prompt": "What is the capital of France?", "completion": "Paris."}
|
||||
```
|
||||
|
||||
For the `completions` data format, a different key can be used for the prompt
|
||||
and completion by specifying the following in the YAML config:
|
||||
|
||||
```yaml
|
||||
prompt_feature: "input"
|
||||
completion_feature: "output"
|
||||
```
|
||||
|
||||
Here, `"input"` is the expected key instead of the default `"prompt"`, and
|
||||
`"output"` is the expected key instead of `"completion"`.
|
||||
|
||||
`text`:
|
||||
|
||||
```jsonl
|
||||
{"text": "This is an example for the model."}
|
||||
```
|
||||
|
||||
Note, the format is automatically determined by the dataset. Note also, keys
|
||||
in each line not expected by the loader will be ignored.
|
||||
|
||||
> [!NOTE]
|
||||
> Each example in the datasets must be on a single line. Do not put more than
|
||||
> one example per line and do not split an example across multiple lines.
|
||||
|
||||
### Hugging Face Datasets
|
||||
|
||||
To use Hugging Face datasets, first install the `datasets` package:
|
||||
|
||||
```
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
If the Hugging Face dataset is already in a supported format, you can specify
|
||||
it on the command line. For example, pass `--data mlx-community/wikisql` to
|
||||
train on the pre-formatted WikiwSQL data.
|
||||
|
||||
Otherwise, provide a mapping of keys in the dataset to the features MLX LM
|
||||
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
|
||||
example:
|
||||
|
||||
```yaml
|
||||
hf_dataset:
|
||||
name: "billsum"
|
||||
prompt_feature: "text"
|
||||
completion_feature: "summary"
|
||||
```
|
||||
|
||||
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||
dataset.
|
||||
|
||||
- To specify the train, valid, or test splits, set the corresponding
|
||||
`{train,valid,test}_split` argument.
|
||||
|
||||
- Arguments specified in `config` will be passed as keyword arguments to
|
||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||
|
||||
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
|
||||
[chat
|
||||
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
are used. This applies the model's chat template by default. If the model does
|
||||
not have a chat template, then Hugging Face will use a default. For example,
|
||||
the final text in the `chat` example above with Hugging Face's default template
|
||||
becomes:
|
||||
|
||||
```text
|
||||
<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Hello.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
How can I assistant you today.<|im_end|>
|
||||
```
|
||||
|
||||
If you are unsure of the format to use, the `chat` or `completions` are good to
|
||||
start with. For custom requirements on the format of the dataset, use the
|
||||
`text` format to assemble the content yourself.
|
||||
|
||||
## Memory Issues
|
||||
|
||||
Fine-tuning a large model with LoRA requires a machine with a decent amount
|
||||
of memory. Here are some tips to reduce memory use should you need to do so:
|
||||
|
||||
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
|
||||
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
|
||||
more details.
|
||||
|
||||
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
|
||||
setting this to `2` or `1` will reduce memory consumption. This may slow
|
||||
things down a little, but will also reduce the memory use.
|
||||
|
||||
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
|
||||
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
||||
needed for back propagation. It may also reduce the quality of the
|
||||
fine-tuned model if you are fine-tuning with a lot of data.
|
||||
|
||||
4. Longer examples require more memory. If it makes sense for your data, one thing
|
||||
you can do is break your examples into smaller
|
||||
sequences when making the `{train, valid, test}.jsonl` files.
|
||||
|
||||
5. Gradient checkpointing lets you trade-off memory use (less) for computation
|
||||
(more) by recomputing instead of storing intermediate values needed by the
|
||||
backward pass. You can use gradient checkpointing by passing the
|
||||
`--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
|
||||
larger batch sizes or sequence lengths with smaller or quantized models.
|
||||
|
||||
For example, for a machine with 32 GB the following should run reasonably fast:
|
||||
|
||||
```
|
||||
mlx_lm.lora \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--train \
|
||||
--batch-size 1 \
|
||||
--num-layers 4 \
|
||||
--data wikisql
|
||||
```
|
||||
|
||||
The above command on an M1 Max with 32 GB runs at about 250
|
||||
tokens-per-second, using the MLX Example
|
||||
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
|
||||
data set.
|
||||
|
||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||
|
||||
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
||||
@@ -1,22 +0,0 @@
|
||||
# Managing Models
|
||||
|
||||
You can use `mlx-lm` to manage models downloaded locally in your machine. They
|
||||
are stored in the Hugging Face cache.
|
||||
|
||||
Scan models:
|
||||
|
||||
```shell
|
||||
mlx_lm.manage --scan
|
||||
```
|
||||
|
||||
Specify a `--pattern` to get info on a single or specific set of models:
|
||||
|
||||
```shell
|
||||
mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
|
||||
```
|
||||
|
||||
To delete a model (or multiple models):
|
||||
|
||||
```shell
|
||||
mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
|
||||
```
|
||||
@@ -1,50 +0,0 @@
|
||||
# Model Merging
|
||||
|
||||
You can use `mlx-lm` to merge models and upload them to the Hugging
|
||||
Face hub or save them locally for LoRA fine tuning.
|
||||
|
||||
The main command is `mlx_lm.merge`:
|
||||
|
||||
```shell
|
||||
mlx_lm.merge --config config.yaml
|
||||
```
|
||||
|
||||
The merged model will be saved by default in `mlx_merged_model`. To see a
|
||||
full list of options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.merge --help
|
||||
```
|
||||
|
||||
Here is an example `config.yaml`:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- OpenPipe/mistral-ft-optimized-1218
|
||||
- mlabonne/NeuralHermes-2.5-Mistral-7B
|
||||
method: slerp
|
||||
parameters:
|
||||
t:
|
||||
- filter: self_attn
|
||||
value: [0, 0.5, 0.3, 0.7, 1]
|
||||
- filter: mlp
|
||||
value: [1, 0.5, 0.7, 0.3, 0]
|
||||
- value: 0.5
|
||||
```
|
||||
|
||||
The `models` field is a list of Hugging Face repo ids. The first model in the
|
||||
list is treated as the base model into which the remaining models are merged.
|
||||
|
||||
The `method` field is the merging method. Right now `slerp` is the only
|
||||
supported method.
|
||||
|
||||
The `parameters` are the corresponding parameters for the given `method`.
|
||||
Each parameter is a list with `filter` determining which layer the parameter
|
||||
applies to and `value` determining the actual value used. The last item in
|
||||
the list without a `filter` field is the default.
|
||||
|
||||
If `value` is a list, it specifies the start and end values for the
|
||||
corresponding segment of blocks. In the example above, the models have 32
|
||||
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
|
||||
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
|
||||
will use `np.linspace(0.5, 0.3, 8)`, and so on.
|
||||
@@ -1,10 +0,0 @@
|
||||
## Generate Text with MLX and :hugs: Hugging Face
|
||||
|
||||
This an example of large language model text generation that can pull models from
|
||||
the Hugging Face Hub.
|
||||
|
||||
For more information on this example, see the [README](../README.md) in the
|
||||
parent directory.
|
||||
|
||||
This package also supports fine tuning with LoRA or QLoRA. For more information
|
||||
see the [LoRA documentation](LORA.md).
|
||||
@@ -1,131 +0,0 @@
|
||||
# HTTP Model Server
|
||||
|
||||
You use `mlx-lm` to make an HTTP API for generating text with any supported
|
||||
model. The HTTP API is intended to be similar to the [OpenAI chat
|
||||
API](https://platform.openai.com/docs/api-reference).
|
||||
|
||||
> [!NOTE]
|
||||
> The MLX LM server is not recommended for production as it only implements
|
||||
> basic security checks.
|
||||
|
||||
Start the server with:
|
||||
|
||||
```shell
|
||||
mlx_lm.server --model <path_to_model_or_hf_repo>
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```shell
|
||||
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
|
||||
```
|
||||
|
||||
This will start a text generation server on port `8080` of the `localhost`
|
||||
using Mistral 7B instruct. The model will be downloaded from the provided
|
||||
Hugging Face repo if it is not already in the local cache.
|
||||
|
||||
To see a full list of options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.server --help
|
||||
```
|
||||
|
||||
You can make a request to the model by running:
|
||||
|
||||
```shell
|
||||
curl localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
### Request Fields
|
||||
|
||||
- `messages`: An array of message objects representing the conversation
|
||||
history. Each message object should have a role (e.g. user, assistant) and
|
||||
content (the message text).
|
||||
|
||||
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
||||
the generated prompt. If not provided, the default mappings are used.
|
||||
|
||||
- `stop`: (Optional) An array of strings or a single string. These are
|
||||
sequences of tokens on which the generation should stop.
|
||||
|
||||
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
||||
to generate. Defaults to `100`.
|
||||
|
||||
- `stream`: (Optional) A boolean indicating if the response should be
|
||||
streamed. If true, responses are sent as they are generated. Defaults to
|
||||
false.
|
||||
|
||||
- `temperature`: (Optional) A float specifying the sampling temperature.
|
||||
Defaults to `1.0`.
|
||||
|
||||
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
|
||||
Defaults to `1.0`.
|
||||
|
||||
- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
|
||||
Defaults to `1.0`.
|
||||
|
||||
- `repetition_context_size`: (Optional) The size of the context window for
|
||||
applying repetition penalty. Defaults to `20`.
|
||||
|
||||
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
|
||||
values. Defaults to `None`.
|
||||
|
||||
- `logprobs`: (Optional) An integer specifying the number of top tokens and
|
||||
corresponding log probabilities to return for each output in the generated
|
||||
sequence. If set, this can be any value between 1 and 10, inclusive.
|
||||
|
||||
- `model`: (Optional) A string path to a local model or Hugging Face repo id.
|
||||
If the path is local is must be relative to the directory the server was
|
||||
started in.
|
||||
|
||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||
relative to the directory the server was started in.
|
||||
|
||||
### Response Fields
|
||||
|
||||
- `id`: A unique identifier for the chat.
|
||||
|
||||
- `system_fingerprint`: A unique identifier for the system.
|
||||
|
||||
- `object`: Any of "chat.completion", "chat.completion.chunk" (for
|
||||
streaming), or "text.completion".
|
||||
|
||||
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
|
||||
|
||||
- `created`: A time-stamp for when the request was processed.
|
||||
|
||||
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
|
||||
- `index`: The index in the list.
|
||||
- `logprobs`: A dictionary containing the fields:
|
||||
- `token_logprobs`: A list of the log probabilities for the generated
|
||||
tokens.
|
||||
- `tokens`: A list of the generated token ids.
|
||||
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
|
||||
top tokens (if requested) with their corresponding probabilities.
|
||||
- `finish_reason`: The reason the completion ended. This can be either of
|
||||
`"stop"` or `"length"`.
|
||||
- `message`: The text response from the model.
|
||||
|
||||
- `usage`: A dictionary containing the fields:
|
||||
- `prompt_tokens`: The number of prompt tokens processed.
|
||||
- `completion_tokens`: The number of tokens generated.
|
||||
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
|
||||
|
||||
### List Models
|
||||
|
||||
Use the `v1/models` endpoint to list available models:
|
||||
|
||||
```shell
|
||||
curl localhost:8080/v1/models -H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
This will return a list of locally available models where each model in the
|
||||
list contains the following fields:
|
||||
|
||||
- `id`: The Hugging Face repo id.
|
||||
- `created`: A time-stamp representing the model creation time.
|
||||
@@ -1,37 +0,0 @@
|
||||
### Packaging for PyPI
|
||||
|
||||
Install `build` and `twine`:
|
||||
|
||||
```
|
||||
pip install --user --upgrade build
|
||||
pip install --user --upgrade twine
|
||||
```
|
||||
|
||||
Generate the source distribution and wheel:
|
||||
|
||||
```
|
||||
python -m build
|
||||
```
|
||||
|
||||
> [!warning]
|
||||
> Use a test server first
|
||||
|
||||
#### Test Upload
|
||||
|
||||
Upload to test server:
|
||||
|
||||
```
|
||||
python -m twine upload --repository testpypi dist/*
|
||||
```
|
||||
|
||||
Install from test server and check that it works:
|
||||
|
||||
```
|
||||
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
|
||||
```
|
||||
|
||||
#### Upload
|
||||
|
||||
```
|
||||
python -m twine upload dist/*
|
||||
```
|
||||
@@ -1,9 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import os
|
||||
|
||||
from ._version import __version__
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
from .utils import convert, generate, load, stream_generate
|
||||
@@ -1,3 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.21.0"
|
||||
@@ -1,161 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import make_prompt_cache, save_prompt_cache
|
||||
from .utils import generate_step, load
|
||||
|
||||
DEFAULT_QUANTIZED_KV_START = 5000
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Cache the state of a prompt to be reused with mlx_lm.generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Enable trusting remote code for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos-token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="End of sequence token for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore-chat-template",
|
||||
action="store_true",
|
||||
help="Use the raw prompt without the tokenizer's chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-default-chat-template",
|
||||
action="store_true",
|
||||
help="Use the default chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the maximum key-value cache size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-cache-file",
|
||||
help="The file to save the prompt cache in",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
required=True,
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-bits",
|
||||
type=int,
|
||||
help="Number of bits for KV cache quantization. "
|
||||
"Defaults to no quantization.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-group-size",
|
||||
type=int,
|
||||
help="Group size for KV cache quantization.",
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantized-kv-start",
|
||||
help="When --kv-bits is set, start quantizing the KV cache "
|
||||
"from this step onwards.",
|
||||
type=int,
|
||||
default=DEFAULT_QUANTIZED_KV_START,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
model, tokenizer = load(
|
||||
args.model,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
|
||||
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
if not args.ignore_chat_template and tokenizer.chat_template is not None:
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=False, continue_final_message=True
|
||||
)
|
||||
|
||||
else:
|
||||
prompt = tokenizer.encode(args.prompt)
|
||||
|
||||
cache = make_prompt_cache(model, args.max_kv_size)
|
||||
y = mx.array(prompt)
|
||||
|
||||
# Process the prompt
|
||||
start = time.time()
|
||||
max_msg_len = 0
|
||||
|
||||
def callback(processed, total_tokens):
|
||||
current = time.time()
|
||||
speed = processed / (current - start)
|
||||
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
|
||||
nonlocal max_msg_len
|
||||
max_msg_len = max(max_msg_len, len(msg))
|
||||
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
|
||||
|
||||
for _ in generate_step(
|
||||
y,
|
||||
model,
|
||||
max_tokens=0,
|
||||
prompt_cache=cache,
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
prompt_progress_callback=callback,
|
||||
):
|
||||
pass
|
||||
|
||||
print()
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
|
||||
|
||||
print("Saving...")
|
||||
metadata = {}
|
||||
metadata["model"] = args.model
|
||||
metadata["chat_template"] = tokenizer.chat_template
|
||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,89 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import make_prompt_cache
|
||||
from .sample_utils import make_sampler
|
||||
from .utils import load, stream_generate
|
||||
|
||||
DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MAX_TOKENS = 256
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
help="Set the maximum key-value cache size",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load(
|
||||
args.model,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "q":
|
||||
break
|
||||
messages = [{"role": "user", "content": query}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
for response in stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
max_tokens=args.max_tokens,
|
||||
sampler=make_sampler(args.temp, args.top_p),
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
print(response.text, flush=True, end="")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
from .utils import convert
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
|
||||
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
|
||||
parser.add_argument(
|
||||
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size", help="Group size for quantization.", type=int, default=64
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the non-quantized parameters.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dequantize",
|
||||
help="Dequantize a quantized model.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
convert(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,392 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Adapted from a PyTorch implementation by David Grangier
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import lm_eval
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from tqdm import tqdm
|
||||
|
||||
from .models.cache import make_prompt_cache
|
||||
from .utils import load, stream_generate
|
||||
|
||||
PAD = 0
|
||||
|
||||
|
||||
def _len_longest_common_prefix(a, b):
|
||||
l = 0
|
||||
for item_a, item_b in zip(a, b):
|
||||
if item_a != item_b:
|
||||
break
|
||||
l += 1
|
||||
return l
|
||||
|
||||
|
||||
def _rstrip_until(s, untils):
|
||||
"""Limit a string <s> to the first occurrence of any substring in untils."""
|
||||
l = len(s)
|
||||
f = [s.find(u) for u in untils]
|
||||
f = [l if x < 0 else x for x in f]
|
||||
return s[: min(f)]
|
||||
|
||||
|
||||
def _pad_inputs(
|
||||
inputs,
|
||||
maxlen,
|
||||
genlen=0,
|
||||
pad_left=False,
|
||||
pad_multiple=32,
|
||||
truncate=False,
|
||||
):
|
||||
# pad the prompts to the left with at least genlen tokens.
|
||||
actual_maxlen = max(len(p) for p in inputs) + genlen
|
||||
if actual_maxlen > maxlen:
|
||||
if not truncate:
|
||||
raise ValueError("Inputs are too long.")
|
||||
else: # drop begining
|
||||
actual_maxlen = maxlen
|
||||
inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
|
||||
if pad_multiple > 0:
|
||||
maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
|
||||
maxlen *= pad_multiple
|
||||
assert PAD == 0
|
||||
lr = np.array((1, 0) if pad_left else (0, 1))
|
||||
return np.stack(
|
||||
[np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
|
||||
@register_model("mlxlm")
|
||||
class MLXLM(LM):
|
||||
def __init__(
|
||||
self,
|
||||
path_or_hf_repo: str,
|
||||
batch_size: int = 16,
|
||||
max_tokens: Optional[int] = None,
|
||||
use_chat_template: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._batch_size = batch_size
|
||||
self._model, self.tokenizer = load(path_or_hf_repo)
|
||||
self._max_tokens = max_tokens or self.tokenizer.model_max_length
|
||||
self.use_chat_template = use_chat_template or (
|
||||
self.tokenizer.chat_template is not None
|
||||
)
|
||||
|
||||
def _score_fn(self, inputs, tokenize=True, step_size=32):
|
||||
if tokenize:
|
||||
inputs = self._tokenize(inputs)
|
||||
inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
|
||||
inputs = mx.array(inputs)
|
||||
inputs, targets = inputs[..., :-1], inputs[..., 1:]
|
||||
|
||||
cache = make_prompt_cache(self._model)
|
||||
|
||||
mask = targets != PAD
|
||||
|
||||
scores, is_greedy = [], []
|
||||
for i in range(0, inputs.shape[1], step_size):
|
||||
logits = self._model(inputs[:, i : i + step_size], cache=cache)
|
||||
|
||||
log_probs = nn.log_softmax(logits.astype(mx.float32))
|
||||
score = mx.take_along_axis(
|
||||
log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
|
||||
)[..., 0]
|
||||
ig = mask[:, i : i + step_size] * (
|
||||
targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
|
||||
)
|
||||
|
||||
mx.eval(score, ig)
|
||||
mx.metal.clear_cache()
|
||||
|
||||
is_greedy.append(ig)
|
||||
scores.append(score)
|
||||
|
||||
scores = mx.concatenate(scores, axis=1)
|
||||
is_greedy = mx.concatenate(is_greedy, axis=1)
|
||||
|
||||
return scores, mask.sum(axis=-1), is_greedy
|
||||
|
||||
def _loglikelihood(self, texts, score_spans=None, tokenize=True):
|
||||
# sort by length to get batches with little padding.
|
||||
sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
|
||||
sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
|
||||
sorted_spans = None
|
||||
if score_spans is not None:
|
||||
sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
|
||||
|
||||
results = []
|
||||
for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
|
||||
batch = sorted_inputs[i : i + self._batch_size]
|
||||
scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
|
||||
for j in range(len(batch)):
|
||||
if sorted_spans is None: # full sequence score
|
||||
mask = mx.arange(scores[j].shape[-1]) < length
|
||||
score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
|
||||
ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
|
||||
else: # subsequence score
|
||||
start, end = sorted_spans[i + j]
|
||||
score = scores[j][start:end].astype(mx.float32).sum()
|
||||
ig = is_greedy[j][start:end].astype(mx.int32).sum()
|
||||
length = end - start
|
||||
|
||||
results.append((score.item(), ig.item(), length))
|
||||
|
||||
# reorder the outputs
|
||||
inv_sort = np.argsort(sorted_indices)
|
||||
results = [results[inv_sort[i]] for i in range(len(results))]
|
||||
|
||||
return results
|
||||
|
||||
def _tokenize(self, texts):
|
||||
return [
|
||||
tuple(
|
||||
self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
|
||||
)
|
||||
for t in texts
|
||||
]
|
||||
|
||||
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
|
||||
"""Compute log-likelihood of generating a continuation from a context.
|
||||
Downstream tasks should attempt to use loglikelihood instead of other
|
||||
LM calls whenever possible.
|
||||
:param requests: list[Instance]
|
||||
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
|
||||
`context: str`
|
||||
Context string. Implementations of LM must be able to handle an
|
||||
empty context string.
|
||||
`continuation: str`
|
||||
The continuation over which log likelihood will be calculated. If
|
||||
there is a word boundary, the space should be in the continuation.
|
||||
For example, context="hello" continuation=" world" is correct.
|
||||
:return: list[tuple[float, bool]]
|
||||
A list of pairs (logprob, isgreedy)
|
||||
`logprob: float`
|
||||
The log probability of `continuation`.
|
||||
`isgreedy`:
|
||||
Whether `continuation` would be generated by greedy sampling from `context`.
|
||||
"""
|
||||
logging.info("Estimating loglikelihood for %d pairs." % len(requests))
|
||||
|
||||
# tokenize prefix and prefix + completion for all requests.
|
||||
tokenized = self._tokenize(
|
||||
[t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
|
||||
)
|
||||
|
||||
# max length (prefix + completion) and longest common prefix per question.
|
||||
length_stats = {}
|
||||
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
|
||||
max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
|
||||
length_stats[prefix] = (
|
||||
max(max_completed_l, len(completed)),
|
||||
min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
|
||||
)
|
||||
|
||||
# truncate requests for completed sequences longer than model context.
|
||||
shortened = []
|
||||
completion_spans = []
|
||||
long_completions = 0
|
||||
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
|
||||
max_completed_l, prefix_l = length_stats[prefix]
|
||||
# compute truncation length
|
||||
truncation = max(0, max_completed_l - self._max_tokens - 1)
|
||||
prefix_l = prefix_l - truncation
|
||||
if prefix_l <= 0:
|
||||
# completion too long, prefix is eliminated for some requests.
|
||||
long_completions += 1
|
||||
truncation = max(0, len(completed) - self._max_tokens - 1)
|
||||
prefix_l = 1
|
||||
# truncate the completed sequence
|
||||
completed = completed[truncation:]
|
||||
shortened.append(completed)
|
||||
# scores do not include initial bos, substract 1 to span bounds
|
||||
completion_spans.append((prefix_l - 1, len(completed) - 1))
|
||||
|
||||
if long_completions > 0:
|
||||
logging.info(
|
||||
f"Prefix eliminated for {long_completions} requests with "
|
||||
+ "completion longer than context."
|
||||
)
|
||||
|
||||
# model scoring, returns num_requests x (logp, is_greedy, length).
|
||||
results = self._loglikelihood(
|
||||
shortened,
|
||||
score_spans=completion_spans,
|
||||
tokenize=False,
|
||||
)
|
||||
return [(r[0], r[1] == r[2]) for r in results]
|
||||
|
||||
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
|
||||
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
|
||||
|
||||
def loglikelihood_rolling(self, requests) -> list[float]:
|
||||
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
||||
- We will use the full max context length of the model.
|
||||
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
|
||||
the max context length.
|
||||
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
|
||||
which may simply concatenate multiple documents together.
|
||||
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
|
||||
multiple chunks, the last input will still a full-sized context.
|
||||
Example:
|
||||
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
|
||||
Prefix: EOT
|
||||
Max context length: 4
|
||||
Resulting input/prediction pairs:
|
||||
INPUT: EOT 0 1 2
|
||||
PRED: 0 1 2 3
|
||||
INPUT: 3 4 5 6
|
||||
PRED: 4 5 6 7
|
||||
INPUT: 5 6 7 8
|
||||
PRED: 8 9
|
||||
Observe that:
|
||||
1. Each token is predicted exactly once
|
||||
2. For the last pair, we provide the full context, but only score the last two tokens
|
||||
:param requests: list[Instance]
|
||||
A list of Instance objects with property `args` which returns a tuple (context,).
|
||||
string: str
|
||||
String for which we are computing overall loglikelihood
|
||||
:return: list[tuple[float]]
|
||||
A list of tuples (logprob,)
|
||||
logprob: float
|
||||
The log probability of `context` conditioned on the EOT token.
|
||||
"""
|
||||
logging.info(
|
||||
"Estimating loglikelihood rolling for %d sequences." % len(requests)
|
||||
)
|
||||
inputs = [req.args[0] for req in requests]
|
||||
return [t[0] for t in self._loglikelihood(inputs)]
|
||||
|
||||
def generate_until(self, requests) -> list[str]:
|
||||
"""Generate greedily until a stopping sequence
|
||||
:param requests: list[Instance]
|
||||
A list of Instance objects with property `args` which returns a tuple (context, until).
|
||||
context: str
|
||||
Context string
|
||||
until: [str]
|
||||
The string sequences to generate until. These string sequences
|
||||
may each span across multiple tokens, or may be part of one token.
|
||||
:return: list[str]
|
||||
A list of strings continuation
|
||||
continuation: str
|
||||
The generated continuation.
|
||||
"""
|
||||
logging.info("Generating continuation for %d sequences." % len(requests))
|
||||
contexts, options = zip(*[req.args for req in requests])
|
||||
# contrary to the doc the second element of the tuple contains
|
||||
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
|
||||
keys = list(options[0].keys())
|
||||
assert "until" in keys
|
||||
untils = [x["until"] for x in options]
|
||||
completions = []
|
||||
|
||||
for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
|
||||
context = self._tokenize(context)
|
||||
max_tokens = min(
|
||||
self._max_tokens,
|
||||
self.tokenizer.model_max_length - len(context),
|
||||
)
|
||||
text = ""
|
||||
for response in stream_generate(
|
||||
self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
|
||||
):
|
||||
text += response.text
|
||||
if any(u in text for u in until):
|
||||
text = _rstrip_until(text, until)
|
||||
completions.append(text)
|
||||
break
|
||||
else:
|
||||
completions.append(text)
|
||||
return completions
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
"Evaluate an MLX model using lm-evaluation-harness."
|
||||
)
|
||||
parser.add_argument("--model", help="Model to evaluate", required=True)
|
||||
parser.add_argument("--tasks", nargs="+", required=True)
|
||||
parser.add_argument(
|
||||
"--output-dir", default=".", help="Output directory for result files."
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
|
||||
parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
default=1.0,
|
||||
help="Limit the number of examples per task.",
|
||||
type=float,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
|
||||
parser.add_argument(
|
||||
"--fewshot-as-multiturn",
|
||||
action="store_true",
|
||||
help="Whether to provide the fewshot examples as a multiturn "
|
||||
"conversation or a single user turn.",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply-chat-template",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Specifies whether to apply a chat template to the prompt. If "
|
||||
"the model has a chat template, this defaults to `True`, "
|
||||
"otherwise `False`.",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Silence tokenizer warnings
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
lm = MLXLM(
|
||||
args.model,
|
||||
batch_size=args.batch_size,
|
||||
max_tokens=args.max_tokens,
|
||||
use_chat_template=args.apply_chat_template,
|
||||
)
|
||||
results = lm_eval.simple_evaluate(
|
||||
model=lm,
|
||||
tasks=args.tasks,
|
||||
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
||||
apply_chat_template=lm.use_chat_template,
|
||||
num_fewshot=args.num_shots,
|
||||
limit=args.limit,
|
||||
random_seed=args.seed,
|
||||
numpy_random_seed=args.seed,
|
||||
torch_random_seed=args.seed,
|
||||
fewshot_random_seed=args.seed,
|
||||
)
|
||||
|
||||
model_name = args.model.replace("/", "_")
|
||||
task_names = "_".join(args.tasks)
|
||||
ver = version("lm_eval")
|
||||
filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
|
||||
output_path = output_dir / filename
|
||||
output_path.write_text(json.dumps(results["results"], indent=4))
|
||||
print("Results:")
|
||||
for result in results["results"].values():
|
||||
print(json.dumps(result, indent=4))
|
||||
@@ -1,48 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
An example of a multi-turn chat with prompt caching.
|
||||
"""
|
||||
|
||||
from mlx_lm import generate, load
|
||||
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
|
||||
|
||||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||
|
||||
# Make the initial prompt cache for the model
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
|
||||
# User turn
|
||||
prompt = "Hi my name is <Name>."
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
|
||||
# User turn
|
||||
prompt = "What's my name?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
|
||||
# Save the prompt cache to disk to reuse it at a later time
|
||||
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
|
||||
|
||||
# Load the prompt cache from disk
|
||||
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
|
||||
@@ -1,33 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from mlx_lm import generate, load
|
||||
|
||||
# Specify the checkpoint
|
||||
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
# Load the corresponding model and tokenizer
|
||||
model, tokenizer = load(path_or_hf_repo=checkpoint)
|
||||
|
||||
# Specify the prompt and conversation history
|
||||
prompt = "Why is the sky blue?"
|
||||
conversation = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Transform the prompt into the chat template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Specify the maximum number of tokens
|
||||
max_tokens = 1_000
|
||||
|
||||
# Specify if tokens and timing information will be printed
|
||||
verbose = True
|
||||
|
||||
# Generate a response with the specified settings
|
||||
response = generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
verbose=verbose,
|
||||
)
|
||||
@@ -1,80 +0,0 @@
|
||||
# The path to the local model directory or Hugging Face repo.
|
||||
model: "mlx_model"
|
||||
|
||||
# Whether or not to train (boolean)
|
||||
train: true
|
||||
|
||||
# The fine-tuning method: "lora", "dora", or "full".
|
||||
fine_tune_type: lora
|
||||
|
||||
# Directory with {train, valid, test}.jsonl files
|
||||
data: "/path/to/training/data"
|
||||
|
||||
# The PRNG seed
|
||||
seed: 0
|
||||
|
||||
# Number of layers to fine-tune
|
||||
num_layers: 16
|
||||
|
||||
# Minibatch size.
|
||||
batch_size: 4
|
||||
|
||||
# Iterations to train for.
|
||||
iters: 1000
|
||||
|
||||
# Number of validation batches, -1 uses the entire validation set.
|
||||
val_batches: 25
|
||||
|
||||
# Adam learning rate.
|
||||
learning_rate: 1e-5
|
||||
|
||||
# Number of training steps between loss reporting.
|
||||
steps_per_report: 10
|
||||
|
||||
# Number of training steps between validations.
|
||||
steps_per_eval: 200
|
||||
|
||||
# Load path to resume training with the given adapter weights.
|
||||
resume_adapter_file: null
|
||||
|
||||
# Save/load path for the trained adapter weights.
|
||||
adapter_path: "adapters"
|
||||
|
||||
# Save the model every N iterations.
|
||||
save_every: 100
|
||||
|
||||
# Evaluate on the test set after training
|
||||
test: false
|
||||
|
||||
# Number of test set batches, -1 uses the entire test set.
|
||||
test_batches: 100
|
||||
|
||||
# Maximum sequence length.
|
||||
max_seq_length: 2048
|
||||
|
||||
# Use gradient checkpointing to reduce memory use.
|
||||
grad_checkpoint: false
|
||||
|
||||
# LoRA parameters can only be specified in a config file
|
||||
lora_parameters:
|
||||
# The layer keys to apply LoRA to.
|
||||
# These will be applied for the last lora_layers
|
||||
keys: ["self_attn.q_proj", "self_attn.v_proj"]
|
||||
rank: 8
|
||||
scale: 20.0
|
||||
dropout: 0.0
|
||||
|
||||
# Schedule can only be specified in a config file, uncomment to use.
|
||||
#lr_schedule:
|
||||
# name: cosine_decay
|
||||
# warmup: 100 # 0 for no warmup
|
||||
# warmup_init: 1e-7 # 0 if not specified
|
||||
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
|
||||
|
||||
#hf_dataset:
|
||||
# name: "billsum"
|
||||
# train_split: "train[:1000]"
|
||||
# valid_split: "train[-100:]"
|
||||
# prompt_feature: "text"
|
||||
# completion_feature: "summary"
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
models:
|
||||
- OpenPipe/mistral-ft-optimized-1218
|
||||
- mlabonne/NeuralHermes-2.5-Mistral-7B
|
||||
method: slerp
|
||||
parameters:
|
||||
t:
|
||||
- filter: self_attn
|
||||
value: [0, 0.5, 0.3, 0.7, 1]
|
||||
- filter: mlp
|
||||
value: [1, 0.5, 0.7, 0.3, 0]
|
||||
- value: 0.5
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Run with:
|
||||
|
||||
```
|
||||
mlx.launch \
|
||||
--hostfile /path/to/hosts.txt \
|
||||
--backend mpi \
|
||||
/path/to/pipeline_generate.py \
|
||||
--prompt "hello world"
|
||||
```
|
||||
|
||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||
documentation:
|
||||
|
||||
https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx_lm import load, stream_generate
|
||||
from mlx_lm.utils import load_model, load_tokenizer
|
||||
|
||||
|
||||
def download(repo: str, allow_patterns: list[str]) -> Path:
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def shard_and_load(repo):
|
||||
# Get model path with everything but weight safetensors
|
||||
model_path = download(
|
||||
args.model,
|
||||
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
||||
)
|
||||
|
||||
# Lazy load and shard model to figure out
|
||||
# which weights we need
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
|
||||
group = mx.distributed.init(backend="mpi")
|
||||
rank = group.rank()
|
||||
model.model.pipeline(group)
|
||||
|
||||
# Figure out which files we need for the local shard
|
||||
with open(model_path / "model.safetensors.index.json", "r") as fid:
|
||||
weight_index = json.load(fid)["weight_map"]
|
||||
|
||||
local_files = set()
|
||||
for k, _ in tree_flatten(model.parameters()):
|
||||
local_files.add(weight_index[k])
|
||||
|
||||
# Download weights for local shard
|
||||
download(args.model, allow_patterns=local_files)
|
||||
|
||||
# Load and shard the model, and load the weights
|
||||
tokenizer = load_tokenizer(model_path)
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
model.model.pipeline(group)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# Synchronize processes before generation to avoid timeout if downloading
|
||||
# model for the first time.
|
||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx-community/DeepSeek-R1-3bit",
|
||||
help="HF repo or path to local model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
default="Write a quicksort in C++.",
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
group = mx.distributed.init(backend="mpi")
|
||||
rank = group.rank()
|
||||
|
||||
def rprint(*args, **kwargs):
|
||||
if rank == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
model, tokenizer = shard_and_load(args.model)
|
||||
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
for response in stream_generate(
|
||||
model, tokenizer, prompt, max_tokens=args.max_tokens
|
||||
):
|
||||
rprint(response.text, end="", flush=True)
|
||||
|
||||
rprint()
|
||||
rprint("=" * 10)
|
||||
rprint(
|
||||
f"Prompt: {response.prompt_tokens} tokens, "
|
||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
rprint(
|
||||
f"Generation: {response.generation_tokens} tokens, "
|
||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||
@@ -1,130 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from .gguf import convert_to_gguf
|
||||
from .tuner.dora import DoRAEmbedding, DoRALinear
|
||||
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
|
||||
from .tuner.utils import dequantize, load_adapters
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Fuse fine-tuned adapters into the base model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
default="fused_model",
|
||||
help="The path to save the fused model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
default="adapters",
|
||||
help="Path to the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--de-quantize",
|
||||
help="Generate a de-quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-gguf",
|
||||
help="Export model weights in GGUF format.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gguf-path",
|
||||
help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
|
||||
default="ggml-model-f16.gguf",
|
||||
type=str,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
print("Loading pretrained model")
|
||||
args = parse_arguments()
|
||||
|
||||
model_path = get_model_path(args.model)
|
||||
model, config, tokenizer = fetch_from_hub(model_path)
|
||||
|
||||
model.freeze()
|
||||
model = load_adapters(model, args.adapter_path)
|
||||
|
||||
fused_linears = [
|
||||
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
|
||||
]
|
||||
|
||||
if fused_linears:
|
||||
model.update_modules(tree_unflatten(fused_linears))
|
||||
|
||||
if args.de_quantize:
|
||||
print("De-quantizing model")
|
||||
model = dequantize(model)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
save_path = Path(args.save_path)
|
||||
|
||||
save_weights(save_path, weights)
|
||||
|
||||
py_files = glob.glob(str(model_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, save_path)
|
||||
|
||||
tokenizer.save_pretrained(save_path)
|
||||
|
||||
if args.de_quantize:
|
||||
config.pop("quantization", None)
|
||||
|
||||
save_config(config, config_path=save_path / "config.json")
|
||||
|
||||
if args.export_gguf:
|
||||
model_type = config["model_type"]
|
||||
if model_type not in ["llama", "mixtral", "mistral"]:
|
||||
raise ValueError(
|
||||
f"Model type {model_type} not supported for GGUF conversion."
|
||||
)
|
||||
convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))
|
||||
|
||||
if args.upload_repo is not None:
|
||||
hf_path = args.hf_path or (
|
||||
args.model if not Path(args.model).exists() else None
|
||||
)
|
||||
if hf_path is None:
|
||||
raise ValueError(
|
||||
"Must provide original Hugging Face repo to upload local model."
|
||||
)
|
||||
upload_to_hub(args.save_path, args.upload_repo, hf_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,257 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||
from .sample_utils import make_sampler
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_MIN_P = 0.0
|
||||
DEFAULT_MIN_TOKENS_TO_KEEP = 1
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
DEFAULT_QUANTIZED_KV_START = 5000
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
return string.lower() not in ["false", "f"]
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(description="LLM inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help=(
|
||||
"The path to the local model directory or Hugging Face repo. "
|
||||
f"If no model is specified, then {DEFAULT_MODEL} is used."
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extra-eos-token",
|
||||
type=str,
|
||||
default=(),
|
||||
nargs="+",
|
||||
help="Add tokens in the list of eos tokens that stop generation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--system-prompt",
|
||||
default=None,
|
||||
help="System prompt to be used for the chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
default=DEFAULT_PROMPT,
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-tokens-to-keep",
|
||||
type=int,
|
||||
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
||||
help="Minimum tokens to keep for min-p sampling.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
parser.add_argument(
|
||||
"--ignore-chat-template",
|
||||
action="store_true",
|
||||
help="Use the raw prompt without the tokenizer's chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-default-chat-template",
|
||||
action="store_true",
|
||||
help="Use the default chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
help="Set the maximum key-value cache size",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-cache-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-bits",
|
||||
type=int,
|
||||
help="Number of bits for KV cache quantization. "
|
||||
"Defaults to no quantization.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-group-size",
|
||||
type=int,
|
||||
help="Group size for KV cache quantization.",
|
||||
default=64,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantized-kv-start",
|
||||
help="When --kv-bits is set, start quantizing the KV cache "
|
||||
"from this step onwards.",
|
||||
type=int,
|
||||
default=DEFAULT_QUANTIZED_KV_START,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--draft-model",
|
||||
type=str,
|
||||
help="A model to be used for speculative decoding.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-draft-tokens",
|
||||
type=int,
|
||||
help="Number of tokens to draft when using speculative decoding.",
|
||||
default=2,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file,
|
||||
return_metadata=True,
|
||||
)
|
||||
if isinstance(prompt_cache[0], QuantizedKVCache):
|
||||
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
|
||||
raise ValueError(
|
||||
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
|
||||
)
|
||||
if args.kv_group_size != prompt_cache[0].group_size:
|
||||
raise ValueError(
|
||||
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
|
||||
)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = (
|
||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
|
||||
model_path = args.model
|
||||
if using_cache:
|
||||
if model_path is None:
|
||||
model_path = metadata["model"]
|
||||
elif model_path != metadata["model"]:
|
||||
raise ValueError(
|
||||
f"Providing a different model ({model_path}) than that "
|
||||
f"used to create the prompt cache ({metadata['model']}) "
|
||||
"is an error."
|
||||
)
|
||||
model_path = model_path or DEFAULT_MODEL
|
||||
|
||||
model, tokenizer = load(
|
||||
model_path,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
for eos_token in args.extra_eos_token:
|
||||
tokenizer.add_eos_token(eos_token)
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif using_cache:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
|
||||
prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
|
||||
prompt = sys.stdin.read() if prompt == "-" else prompt
|
||||
if not args.ignore_chat_template and tokenizer.chat_template is not None:
|
||||
if args.system_prompt is not None:
|
||||
messages = [{"role": "system", "content": args.system_prompt}]
|
||||
else:
|
||||
messages = []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
# stored kv cache.
|
||||
if using_cache:
|
||||
messages[-1]["content"] = "<query>"
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
else:
|
||||
prompt = tokenizer.encode(prompt)
|
||||
|
||||
if args.draft_model is not None:
|
||||
draft_model, draft_tokenizer = load(args.draft_model)
|
||||
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
|
||||
raise ValueError("Draft model tokenizer does not match model tokenizer.")
|
||||
else:
|
||||
draft_model = None
|
||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
max_tokens=args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
sampler=sampler,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=args.num_draft_tokens,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,314 +0,0 @@
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class TokenType(IntEnum):
|
||||
NORMAL = 1
|
||||
UNKNOWN = 2
|
||||
CONTROL = 3
|
||||
USER_DEFINED = 4
|
||||
UNUSED = 5
|
||||
BYTE = 6
|
||||
|
||||
|
||||
class GGMLFileType(IntEnum):
|
||||
GGML_TYPE_F16 = 1
|
||||
|
||||
|
||||
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
|
||||
class HfVocab:
|
||||
def __init__(
|
||||
self,
|
||||
fname_tokenizer: Path,
|
||||
fname_added_tokens: Optional[Union[Path, None]] = None,
|
||||
) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
fname_tokenizer,
|
||||
cache_dir=fname_tokenizer,
|
||||
local_files_only=True,
|
||||
)
|
||||
self.added_tokens_list = []
|
||||
self.added_tokens_dict = dict()
|
||||
self.added_tokens_ids = set()
|
||||
for tok, tokidx in sorted(
|
||||
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
|
||||
):
|
||||
if tokidx >= self.tokenizer.vocab_size:
|
||||
self.added_tokens_list.append(tok)
|
||||
self.added_tokens_dict[tok] = tokidx
|
||||
self.added_tokens_ids.add(tokidx)
|
||||
self.specials = {
|
||||
tok: self.tokenizer.get_vocab()[tok]
|
||||
for tok in self.tokenizer.all_special_tokens
|
||||
}
|
||||
self.special_ids = set(self.tokenizer.all_special_ids)
|
||||
self.vocab_size_base = self.tokenizer.vocab_size
|
||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
self.fname_added_tokens = fname_added_tokens
|
||||
|
||||
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
reverse_vocab = {
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||
}
|
||||
for token_id in range(self.vocab_size_base):
|
||||
if token_id in self.added_tokens_ids:
|
||||
continue
|
||||
token_text = reverse_vocab[token_id]
|
||||
yield token_text, self.get_token_score(token_id), self.get_token_type(
|
||||
token_id, token_text, self.special_ids
|
||||
)
|
||||
|
||||
def get_token_type(
|
||||
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||
) -> TokenType:
|
||||
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
return TokenType.BYTE
|
||||
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
|
||||
|
||||
def get_token_score(self, token_id: int) -> float:
|
||||
return -1000.0
|
||||
|
||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
if text in self.specials:
|
||||
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
|
||||
score = self.get_token_score(self.specials[text])
|
||||
else:
|
||||
toktype = TokenType.USER_DEFINED
|
||||
score = -1000.0
|
||||
yield text, score, toktype
|
||||
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
|
||||
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
yield from self.hf_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
@staticmethod
|
||||
def load(path: Path) -> "HfVocab":
|
||||
added_tokens_path = path.parent / "added_tokens.json"
|
||||
return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
|
||||
|
||||
|
||||
def translate_weight_names(name):
|
||||
name = name.replace("model.layers.", "blk.")
|
||||
# for mixtral gate
|
||||
name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
|
||||
# for mixtral experts ffns
|
||||
pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
|
||||
replacement = r"ffn_gate.\1.weight"
|
||||
name = re.sub(pattern, replacement, name)
|
||||
pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
|
||||
replacement = r"ffn_down.\1.weight"
|
||||
name = re.sub(pattern, replacement, name)
|
||||
pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
|
||||
replacement = r"ffn_up.\1.weight"
|
||||
name = re.sub(pattern, replacement, name)
|
||||
|
||||
name = name.replace("mlp.gate_proj", "ffn_gate")
|
||||
name = name.replace("mlp.down_proj", "ffn_down")
|
||||
name = name.replace("mlp.up_proj", "ffn_up")
|
||||
name = name.replace("self_attn.q_proj", "attn_q")
|
||||
name = name.replace("self_attn.k_proj", "attn_k")
|
||||
name = name.replace("self_attn.v_proj", "attn_v")
|
||||
name = name.replace("self_attn.o_proj", "attn_output")
|
||||
name = name.replace("input_layernorm", "attn_norm")
|
||||
name = name.replace("post_attention_layernorm", "ffn_norm")
|
||||
name = name.replace("model.embed_tokens", "token_embd")
|
||||
name = name.replace("model.norm", "output_norm")
|
||||
name = name.replace("lm_head", "output")
|
||||
return name
|
||||
|
||||
|
||||
def permute_weights(weights, n_head, n_head_kv=None):
|
||||
if n_head_kv is not None and n_head != n_head_kv:
|
||||
n_head = n_head_kv
|
||||
reshaped = weights.reshape(
|
||||
n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
|
||||
)
|
||||
swapped = reshaped.swapaxes(1, 2)
|
||||
final_shape = weights.shape
|
||||
return swapped.reshape(final_shape)
|
||||
|
||||
|
||||
def prepare_metadata(config, vocab):
|
||||
metadata = {
|
||||
"general.name": "llama",
|
||||
"llama.context_length": (
|
||||
mx.array(config["max_position_embeddings"], dtype=mx.uint32)
|
||||
if config.get("max_position_embeddings") is not None
|
||||
else None
|
||||
),
|
||||
"llama.embedding_length": (
|
||||
mx.array(config["hidden_size"], dtype=mx.uint32)
|
||||
if config.get("hidden_size") is not None
|
||||
else None
|
||||
),
|
||||
"llama.block_count": (
|
||||
mx.array(config["num_hidden_layers"], dtype=mx.uint32)
|
||||
if config.get("num_hidden_layers") is not None
|
||||
else None
|
||||
),
|
||||
"llama.feed_forward_length": (
|
||||
mx.array(config["intermediate_size"], dtype=mx.uint32)
|
||||
if config.get("intermediate_size") is not None
|
||||
else None
|
||||
),
|
||||
"llama.rope.dimension_count": (
|
||||
mx.array(
|
||||
config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
|
||||
)
|
||||
if config.get("hidden_size") is not None
|
||||
and config.get("num_attention_heads") is not None
|
||||
else None
|
||||
),
|
||||
"llama.attention.head_count": (
|
||||
mx.array(config["num_attention_heads"], dtype=mx.uint32)
|
||||
if config.get("num_attention_heads") is not None
|
||||
else None
|
||||
),
|
||||
"llama.attention.head_count_kv": (
|
||||
mx.array(
|
||||
config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||
dtype=mx.uint32,
|
||||
)
|
||||
if config.get("num_attention_heads") is not None
|
||||
else None
|
||||
),
|
||||
"llama.expert_count": (
|
||||
mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
|
||||
if config.get("num_local_experts") is not None
|
||||
else None
|
||||
),
|
||||
"llama.expert_used_count": (
|
||||
mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
|
||||
if config.get("num_experts_per_tok") is not None
|
||||
else None
|
||||
),
|
||||
"llama.attention.layer_norm_rms_epsilon": (
|
||||
mx.array(config.get("rms_norm_eps", 1e-05))
|
||||
if config.get("rms_norm_eps") is not None
|
||||
else None
|
||||
),
|
||||
"llama.rope.freq_base": (
|
||||
mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
|
||||
if config.get("rope_theta") is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
rope_scaling = config.get("rope_scaling")
|
||||
if rope_scaling is not None and (typ := rope_scaling.get("type")):
|
||||
rope_factor = rope_scaling.get("factor")
|
||||
f_rope_scale = rope_factor
|
||||
if typ == "linear":
|
||||
rope_scaling_type = "linear"
|
||||
metadata["llama.rope.scaling.type"] = rope_scaling_type
|
||||
metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)
|
||||
|
||||
metadata["general.file_type"] = mx.array(
|
||||
GGMLFileType.GGML_TYPE_F16.value,
|
||||
dtype=mx.uint32,
|
||||
)
|
||||
metadata["general.quantization_version"] = mx.array(
|
||||
GGMLFileType.GGML_TYPE_F16.value,
|
||||
dtype=mx.uint32,
|
||||
)
|
||||
metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
|
||||
metadata["general.architecture"] = "llama"
|
||||
metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)
|
||||
|
||||
# add metadata for vocab
|
||||
metadata["tokenizer.ggml.model"] = "llama"
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
for text, score, toktype in vocab.all_tokens():
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype.value)
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
metadata["tokenizer.ggml.tokens"] = tokens
|
||||
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
|
||||
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
|
||||
if vocab.tokenizer.bos_token_id is not None:
|
||||
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
|
||||
vocab.tokenizer.bos_token_id, dtype=mx.uint32
|
||||
)
|
||||
if vocab.tokenizer.eos_token_id is not None:
|
||||
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
|
||||
vocab.tokenizer.eos_token_id, dtype=mx.uint32
|
||||
)
|
||||
if vocab.tokenizer.unk_token_id is not None:
|
||||
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
|
||||
vocab.tokenizer.unk_token_id, dtype=mx.uint32
|
||||
)
|
||||
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
return metadata
|
||||
|
||||
|
||||
def convert_to_gguf(
|
||||
model_path: Union[str, Path],
|
||||
weights: dict,
|
||||
config: dict,
|
||||
output_file_path: str,
|
||||
):
|
||||
if isinstance(model_path, str):
|
||||
model_path = Path(model_path)
|
||||
|
||||
quantization = config.get("quantization", None)
|
||||
if quantization:
|
||||
raise NotImplementedError(
|
||||
"Conversion of quantized models is not yet supported."
|
||||
)
|
||||
print("Converting to GGUF format")
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention
|
||||
weights = {
|
||||
k: (
|
||||
permute_weights(
|
||||
v, config["num_attention_heads"], config["num_attention_heads"]
|
||||
)
|
||||
if "self_attn.q_proj.weight" in k
|
||||
else (
|
||||
permute_weights(
|
||||
v, config["num_attention_heads"], config["num_key_value_heads"]
|
||||
)
|
||||
if "self_attn.k_proj.weight" in k
|
||||
else v
|
||||
)
|
||||
)
|
||||
for k, v in weights.items()
|
||||
}
|
||||
|
||||
# rename weights for gguf format
|
||||
weights = {translate_weight_names(k): v for k, v in weights.items()}
|
||||
|
||||
if not (model_path / "tokenizer.json").exists():
|
||||
raise ValueError("Tokenizer json not found")
|
||||
|
||||
vocab = HfVocab.load(model_path)
|
||||
metadata = prepare_metadata(config, vocab)
|
||||
|
||||
weights = {
|
||||
k: (
|
||||
v.astype(mx.float32).astype(mx.float16)
|
||||
if v.dtype == mx.bfloat16
|
||||
else v.astype(mx.float32) if "norm" in k else v
|
||||
)
|
||||
for k, v in weights.items()
|
||||
}
|
||||
|
||||
output_file_path = output_file_path
|
||||
mx.save_gguf(output_file_path, weights, metadata)
|
||||
print(f"Converted GGUF model saved as: {output_file_path}")
|
||||
@@ -1,299 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from .tokenizer_utils import TokenizerWrapper
|
||||
from .tuner.datasets import load_dataset
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.utils import (
|
||||
build_schedule,
|
||||
linear_to_lora_layers,
|
||||
load_adapters,
|
||||
print_trainable_parameters,
|
||||
)
|
||||
from .utils import load, save_config
|
||||
|
||||
yaml_loader = yaml.SafeLoader
|
||||
yaml_loader.add_implicit_resolver(
|
||||
"tag:yaml.org,2002:float",
|
||||
re.compile(
|
||||
"""^(?:
|
||||
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||
|[-+]?\\.(?:inf|Inf|INF)
|
||||
|\\.(?:nan|NaN|NAN))$""",
|
||||
re.X,
|
||||
),
|
||||
list("-+0123456789."),
|
||||
)
|
||||
|
||||
CONFIG_DEFAULTS = {
|
||||
"model": "mlx_model",
|
||||
"train": False,
|
||||
"fine_tune_type": "lora",
|
||||
"data": "data/",
|
||||
"seed": 0,
|
||||
"num_layers": 16,
|
||||
"batch_size": 4,
|
||||
"iters": 1000,
|
||||
"val_batches": 25,
|
||||
"learning_rate": 1e-5,
|
||||
"steps_per_report": 10,
|
||||
"steps_per_eval": 200,
|
||||
"resume_adapter_file": None,
|
||||
"adapter_path": "adapters",
|
||||
"save_every": 100,
|
||||
"test": False,
|
||||
"test_batches": 500,
|
||||
"max_seq_length": 2048,
|
||||
"config": None,
|
||||
"grad_checkpoint": False,
|
||||
"lr_schedule": None,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
}
|
||||
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
|
||||
# Training args
|
||||
parser.add_argument(
|
||||
"--train",
|
||||
action="store_true",
|
||||
help="Do training",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=str,
|
||||
help=(
|
||||
"Directory with {train, valid, test}.jsonl files or the name "
|
||||
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fine-tune-type",
|
||||
type=str,
|
||||
choices=["lora", "dora", "full"],
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
|
||||
parser.add_argument("--iters", type=int, help="Iterations to train for.")
|
||||
parser.add_argument(
|
||||
"--val-batches",
|
||||
type=int,
|
||||
help="Number of validation batches, -1 uses the entire validation set.",
|
||||
)
|
||||
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
|
||||
parser.add_argument(
|
||||
"--steps-per-report",
|
||||
type=int,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps-per-eval",
|
||||
type=int,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume-adapter-file",
|
||||
type=str,
|
||||
help="Load path to resume training from the given fine-tuned weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Save/load path for the fine-tuned weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-every",
|
||||
type=int,
|
||||
help="Save the model every N iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-batches",
|
||||
type=int,
|
||||
help="Number of test set batches, -1 uses the entire test set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-seq-length",
|
||||
type=int,
|
||||
help="Maximum sequence length.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
type=str,
|
||||
help="A YAML configuration file with the training options",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad-checkpoint",
|
||||
action="store_true",
|
||||
help="Use gradient checkpointing to reduce memory use.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||
return parser
|
||||
|
||||
|
||||
def train_model(
|
||||
args,
|
||||
model: nn.Module,
|
||||
tokenizer: TokenizerWrapper,
|
||||
train_set,
|
||||
valid_set,
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
model.freeze()
|
||||
if args.fine_tune_type == "full":
|
||||
for l in model.layers[-min(args.num_layers, 0) :]:
|
||||
l.unfreeze()
|
||||
elif args.fine_tune_type in ["lora", "dora"]:
|
||||
# Convert linear layers to lora/dora layers and unfreeze in the process
|
||||
linear_to_lora_layers(
|
||||
model,
|
||||
args.num_layers,
|
||||
args.lora_parameters,
|
||||
use_dora=(args.fine_tune_type == "dora"),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
|
||||
|
||||
# Resume from weights if provided
|
||||
if args.resume_adapter_file is not None:
|
||||
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file, strict=False)
|
||||
|
||||
print_trainable_parameters(model)
|
||||
|
||||
adapter_path = Path(args.adapter_path)
|
||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
adapter_file = adapter_path / "adapters.safetensors"
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
|
||||
# init training args
|
||||
training_args = TrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
)
|
||||
|
||||
model.train()
|
||||
opt = optim.Adam(
|
||||
learning_rate=(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
)
|
||||
# Train model
|
||||
train(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=training_args,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
training_callback=training_callback,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
|
||||
model.eval()
|
||||
|
||||
test_loss = evaluate(
|
||||
model=model,
|
||||
dataset=test_set,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.test_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
)
|
||||
|
||||
test_ppl = math.exp(test_loss)
|
||||
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
||||
def run(args, training_callback: TrainingCallback = None):
|
||||
np.random.seed(args.seed)
|
||||
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
print("Loading datasets")
|
||||
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
||||
|
||||
if args.test and not args.train:
|
||||
# Allow testing without LoRA layers by providing empty path
|
||||
if args.adapter_path != "":
|
||||
load_adapters(model, args.adapter_path)
|
||||
|
||||
elif args.train:
|
||||
print("Training")
|
||||
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
|
||||
else:
|
||||
raise ValueError("Must provide at least one of --train or --test")
|
||||
|
||||
if args.test:
|
||||
print("Testing")
|
||||
evaluate_model(args, model, tokenizer, test_set)
|
||||
|
||||
|
||||
def main():
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
config = args.config
|
||||
args = vars(args)
|
||||
if config:
|
||||
print("Loading configuration file", config)
|
||||
with open(config, "r") as file:
|
||||
config = yaml.load(file, yaml_loader)
|
||||
# Prefer parameters from command-line arguments
|
||||
for k, v in config.items():
|
||||
if args.get(k, None) is None:
|
||||
args[k] = v
|
||||
|
||||
# Update defaults for unspecified parameters
|
||||
for k, v in CONFIG_DEFAULTS.items():
|
||||
if args.get(k, None) is None:
|
||||
args[k] = v
|
||||
run(types.SimpleNamespace(**args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,124 +0,0 @@
|
||||
import argparse
|
||||
from typing import List, Union
|
||||
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from transformers.commands.user import tabulate
|
||||
|
||||
|
||||
def ask_for_confirmation(message: str) -> bool:
|
||||
"""Ask user for confirmation with Y/N prompt.
|
||||
Returns True for Y/yes, False for N/no/empty."""
|
||||
y = ("y", "yes", "1")
|
||||
n = ("n", "no", "0", "")
|
||||
full_message = f"{message} (y/n) "
|
||||
while True:
|
||||
answer = input(full_message).lower()
|
||||
if answer in y:
|
||||
return True
|
||||
if answer in n:
|
||||
return False
|
||||
print(f"Invalid input. Must be one of: yes/no/y/n or empty for no")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MLX Model Cache.")
|
||||
parser.add_argument(
|
||||
"--scan",
|
||||
action="store_true",
|
||||
help="Scan Hugging Face cache for mlx models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
action="store_true",
|
||||
help="Delete models matching the given pattern.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pattern",
|
||||
type=str,
|
||||
help="Model repos contain the pattern.",
|
||||
default="mlx",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.scan:
|
||||
print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
|
||||
hf_cache_info = scan_cache_dir()
|
||||
print(
|
||||
tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
repo.repo_type,
|
||||
"{:>12}".format(repo.size_on_disk_str),
|
||||
repo.nb_files,
|
||||
repo.last_accessed_str,
|
||||
repo.last_modified_str,
|
||||
str(repo.repo_path),
|
||||
]
|
||||
for repo in sorted(
|
||||
hf_cache_info.repos, key=lambda repo: repo.repo_path
|
||||
)
|
||||
if args.pattern in repo.repo_id
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"REPO TYPE",
|
||||
"SIZE ON DISK",
|
||||
"NB FILES",
|
||||
"LAST_ACCESSED",
|
||||
"LAST_MODIFIED",
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if args.delete:
|
||||
print(f'Deleting models matching pattern "{args.pattern}"')
|
||||
hf_cache_info = scan_cache_dir()
|
||||
|
||||
repos = [
|
||||
repo
|
||||
for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
|
||||
if args.pattern in repo.repo_id
|
||||
]
|
||||
if repos:
|
||||
print("\nFound the following models:")
|
||||
print(
|
||||
tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
repo.size_on_disk_str, # Added size information
|
||||
str(repo.repo_path),
|
||||
]
|
||||
for repo in repos
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"SIZE", # Added size header
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
confirmed = ask_for_confirmation(
|
||||
"\nAre you sure you want to delete these models?"
|
||||
)
|
||||
if confirmed:
|
||||
for model_info in repos:
|
||||
print(f"\nDeleting {model_info.repo_id}...")
|
||||
for revision in sorted(
|
||||
model_info.revisions, key=lambda revision: revision.commit_hash
|
||||
):
|
||||
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
|
||||
strategy.execute()
|
||||
print("\nModel(s) deleted successfully.")
|
||||
else:
|
||||
print("\nDeletion cancelled - no changes made.")
|
||||
else:
|
||||
print(f'No models found matching pattern "{args.pattern}"')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,172 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import yaml
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Merge multiple models.")
|
||||
|
||||
parser.add_argument("--config", type=str, help="Path to the YAML config.")
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_merged_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def slerp(t, w1, w2, eps=1e-5):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
|
||||
Args:
|
||||
t (float): Interpolation weight in [0.0, 1.0]
|
||||
w1 (mx.array): First input
|
||||
w2 (mx.array): Second input
|
||||
eps (float): Constant for numerical stability
|
||||
Returns:
|
||||
mx.array: Interpolated result
|
||||
"""
|
||||
t = float(t)
|
||||
if t == 0:
|
||||
return w1
|
||||
elif t == 1:
|
||||
return w2
|
||||
# Normalize
|
||||
v1 = w1 / mx.linalg.norm(w1)
|
||||
v2 = w2 / mx.linalg.norm(w2)
|
||||
# Angle
|
||||
dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
|
||||
theta = mx.arccos(dot)
|
||||
sin_theta = mx.sin(theta + eps)
|
||||
s1 = mx.sin(theta * (1 - t)) / sin_theta
|
||||
s2 = mx.sin(theta * t) / sin_theta
|
||||
return s1 * w1 + s2 * w2
|
||||
|
||||
|
||||
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
|
||||
method = config.get("method", None)
|
||||
if method != "slerp":
|
||||
raise ValueError(f"Merge method {method} not supported")
|
||||
|
||||
num_layers = len(model.layers)
|
||||
|
||||
def unpack_values(vals):
|
||||
if isinstance(vals, (int, float)):
|
||||
return np.full(num_layers, vals)
|
||||
bins = len(vals) - 1
|
||||
sizes = [num_layers // bins] * bins
|
||||
sizes[-1] = num_layers - sum(sizes[:-1])
|
||||
return np.concatenate(
|
||||
[np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
|
||||
)
|
||||
|
||||
param_list = config["parameters"]["t"]
|
||||
params = {}
|
||||
filter_keys = set()
|
||||
for pl in param_list[:-1]:
|
||||
params[pl["filter"]] = unpack_values(pl["value"])
|
||||
filter_keys.add(pl["filter"])
|
||||
default = unpack_values(param_list[-1]["value"])
|
||||
|
||||
for e in range(num_layers):
|
||||
bl = base_model.layers[e]
|
||||
l = model.layers[e]
|
||||
base_weights = bl.parameters()
|
||||
weights = l.parameters()
|
||||
for k, w1 in base_weights.items():
|
||||
w2 = weights[k]
|
||||
t = params.get(k, default)[e]
|
||||
base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
|
||||
base_model.update(base_weights)
|
||||
|
||||
|
||||
def merge(
|
||||
config: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
upload_repo: Optional[str] = None,
|
||||
):
|
||||
with open(config, "r") as fid:
|
||||
merge_conf = yaml.safe_load(fid)
|
||||
print("[INFO] Loading")
|
||||
|
||||
model_paths = merge_conf.get("models", [])
|
||||
if len(model_paths) < 2:
|
||||
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
|
||||
|
||||
# Load all models
|
||||
base_hf_path = model_paths[0]
|
||||
base_path = get_model_path(base_hf_path)
|
||||
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||
models = []
|
||||
for mp in model_paths[1:]:
|
||||
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||
base_type = base_config["model_type"]
|
||||
model_type = model_config["model_type"]
|
||||
if base_type != model_type:
|
||||
raise ValueError(
|
||||
f"Can only merge models of the same type,"
|
||||
f" but got {base_type} and {model_type}."
|
||||
)
|
||||
models.append(model)
|
||||
|
||||
# Merge models into base model
|
||||
for m in models:
|
||||
merge_models(base_model, m, merge_conf)
|
||||
|
||||
# Save base model
|
||||
mlx_path = Path(mlx_path)
|
||||
weights = dict(tree_flatten(base_model.parameters()))
|
||||
del models, base_model
|
||||
save_weights(mlx_path, weights, donate_weights=True)
|
||||
py_files = glob.glob(str(base_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, mlx_path)
|
||||
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
|
||||
save_config(config, config_path=mlx_path / "config.json")
|
||||
|
||||
if upload_repo is not None:
|
||||
upload_to_hub(mlx_path, upload_repo, base_hf_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
merge(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,121 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
|
||||
from .cache import QuantizedKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelArgs:
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_causal_mask(
|
||||
N: int,
|
||||
offset: int = 0,
|
||||
window_size: Optional[int] = None,
|
||||
lengths: Optional[mx.array] = None,
|
||||
):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
linds = linds[:, None]
|
||||
rinds = rinds[None]
|
||||
mask = linds < rinds
|
||||
if window_size is not None:
|
||||
mask = mask | (linds > rinds + window_size)
|
||||
if lengths is not None:
|
||||
lengths = lengths[:, None, None, None]
|
||||
mask = mask | (rinds >= lengths)
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
window_size = None
|
||||
offset = 0
|
||||
if cache is not None and cache[0] is not None:
|
||||
c = cache[0]
|
||||
if hasattr(c, "max_size"):
|
||||
offset = min(c.max_size, c.offset)
|
||||
window_size = c.max_size
|
||||
else:
|
||||
offset = c.offset
|
||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||
mask = mask.astype(h.dtype)
|
||||
else:
|
||||
mask = None
|
||||
return mask
|
||||
|
||||
|
||||
def quantized_scaled_dot_product_attention(
|
||||
queries: mx.array,
|
||||
q_keys: tuple[mx.array, mx.array, mx.array],
|
||||
q_values: tuple[mx.array, mx.array, mx.array],
|
||||
scale: float,
|
||||
mask: Optional[mx.array],
|
||||
group_size: int = 64,
|
||||
bits: int = 8,
|
||||
) -> mx.array:
|
||||
B, n_q_heads, L, D = queries.shape
|
||||
n_kv_heads = q_keys[0].shape[-3]
|
||||
n_repeats = n_q_heads // n_kv_heads
|
||||
|
||||
queries *= scale
|
||||
|
||||
if n_repeats > 1:
|
||||
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
|
||||
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
|
||||
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
|
||||
|
||||
scores = mx.quantized_matmul(
|
||||
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
|
||||
)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
out = mx.quantized_matmul(
|
||||
scores, *q_values, transpose=False, group_size=group_size, bits=bits
|
||||
)
|
||||
|
||||
if n_repeats > 1:
|
||||
out = mx.reshape(out, (B, n_q_heads, L, D))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache,
|
||||
scale: float,
|
||||
mask: Optional[mx.array],
|
||||
) -> mx.array:
|
||||
if isinstance(cache, QuantizedKVCache):
|
||||
return quantized_scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
group_size=cache.group_size,
|
||||
bits=cache.bits,
|
||||
)
|
||||
else:
|
||||
return mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=scale, mask=mask
|
||||
)
|
||||
@@ -1,438 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(
|
||||
model: nn.Module,
|
||||
max_kv_size: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Construct the model's cache for use when cgeneration.
|
||||
|
||||
This function will defer the cache construction to the model if it has a
|
||||
``make_cache`` method, otherwise it will make a default KV cache.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||
size of ``max_kv_size``
|
||||
"""
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
num_layers = len(model.layers)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
return [KVCache() for _ in range(num_layers)]
|
||||
|
||||
|
||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
cache (List[Any]): The model state.
|
||||
metadata (Dict[str, str]): Optional metadata to save along with model
|
||||
state.
|
||||
"""
|
||||
cache_data = [c.state for c in cache]
|
||||
cache_info = [c.meta_state for c in cache]
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
cache_classes = [type(c).__name__ for c in cache]
|
||||
cache_metadata = [cache_info, metadata, cache_classes]
|
||||
cache_metadata = dict(tree_flatten(cache_metadata))
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
|
||||
def load_prompt_cache(file_name, return_metadata=False):
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
return_metadata (bool): Whether or not to return metadata.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
|
||||
the metadata if requested.
|
||||
"""
|
||||
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
|
||||
arrays = tree_unflatten(list(arrays.items()))
|
||||
cache_metadata = tree_unflatten(list(cache_metadata.items()))
|
||||
info, metadata, classes = cache_metadata
|
||||
cache = [globals()[c]() for c in classes]
|
||||
for c, state, meta_state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
c.meta_state = meta_state
|
||||
if return_metadata:
|
||||
return cache, metadata
|
||||
return cache
|
||||
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
return all(c.is_trimmable() for c in cache)
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
|
||||
This function will trim the cache if possible (in-place) and return the
|
||||
number of tokens that were trimmed.
|
||||
|
||||
Args:
|
||||
cache (List[Any]): The model's cache.
|
||||
num_tokens (int): The number of tokens to trim.
|
||||
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
||||
class _BaseCache:
|
||||
@property
|
||||
def state(self):
|
||||
return []
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no state but a state was set.")
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return ""
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no meta_state but a meta_state was set.")
|
||||
|
||||
def is_trimmable(self):
|
||||
return False
|
||||
|
||||
|
||||
class QuantizedKVCache(_BaseCache):
|
||||
def __init__(self, group_size: int = 64, bits: int = 8):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[-1]
|
||||
prev = self.offset
|
||||
|
||||
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
||||
el_per_int = 8 * mx.uint32.size // self.bits
|
||||
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
||||
shape = (B, n_kv_heads, new_steps)
|
||||
|
||||
def init_quant(dim):
|
||||
return (
|
||||
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
)
|
||||
|
||||
def expand_quant(x):
|
||||
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
|
||||
return mx.concatenate([x, new_x], axis=-2)
|
||||
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys, self.values = tree_map(
|
||||
lambda x: x[..., :prev, :], (self.keys, self.values)
|
||||
)
|
||||
|
||||
self.keys, self.values = tree_map(
|
||||
expand_quant, (self.keys, self.values)
|
||||
)
|
||||
else:
|
||||
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
|
||||
|
||||
self.offset += num_steps
|
||||
|
||||
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
|
||||
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
|
||||
for i in range(len(self.keys)):
|
||||
self.keys[i][..., prev : self.offset, :] = keys[i]
|
||||
self.values[i][..., prev : self.offset, :] = values[i]
|
||||
|
||||
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys[0].shape[2]:
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return tree_map(
|
||||
lambda x: x[..., : self.offset, :], (self.keys, self.values)
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.step, self.offset, self.group_size, self.bits = map(int, v)
|
||||
|
||||
def is_trimmable(self):
|
||||
return True
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B, n_kv_heads, _, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[3]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys.shape[2]:
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return (
|
||||
self.keys[..., : self.offset, :],
|
||||
self.values[..., : self.offset, :],
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
self.offset = self.keys.shape[2]
|
||||
|
||||
def is_trimmable(self):
|
||||
return True
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||
quant_cache.offset = self.offset
|
||||
if self.keys is not None:
|
||||
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
||||
quant_cache.values = mx.quantize(
|
||||
self.values, group_size=group_size, bits=bits
|
||||
)
|
||||
return quant_cache
|
||||
|
||||
|
||||
class RotatingKVCache(_BaseCache):
|
||||
|
||||
def __init__(self, max_size=None, keep=0, step=256):
|
||||
self.keep = keep
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
to_cat = []
|
||||
if trim_size > 0:
|
||||
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||
else:
|
||||
to_cat = [v]
|
||||
if append is not None:
|
||||
to_cat.append(append)
|
||||
return mx.concatenate(to_cat, axis=2)
|
||||
|
||||
def _temporal_order(self, v):
|
||||
"""
|
||||
Rearrange the cache into temporal order, slicing off the end if unused.
|
||||
"""
|
||||
if self._idx == v.shape[2]:
|
||||
return v
|
||||
elif self._idx < self.offset:
|
||||
return mx.concatenate(
|
||||
[
|
||||
v[..., : self.keep, :],
|
||||
v[..., self._idx :, :],
|
||||
v[..., self.keep : self._idx, :],
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
else:
|
||||
return v[..., : self._idx, :]
|
||||
|
||||
def _update_concat(self, keys, values):
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
# Put the keys/values in temporal order to
|
||||
# preserve context
|
||||
self.keys = self._temporal_order(self.keys)
|
||||
self.values = self._temporal_order(self.values)
|
||||
|
||||
# The largest size is self.max_size + S to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self._idx - self.max_size
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += keys.shape[2]
|
||||
self._idx = self.keys.shape[2]
|
||||
return self.keys, self.values
|
||||
|
||||
def _update_in_place(self, keys, values):
|
||||
# May not have hit the max size yet, so potentially
|
||||
# keep growing the cache
|
||||
B, n_kv_heads, S, k_head_dim = keys.shape
|
||||
prev = self.offset
|
||||
if self.keys is None or (
|
||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||
):
|
||||
v_head_dim = values.shape[3]
|
||||
new_size = min(self.step, self.max_size - prev)
|
||||
k_shape = (B, n_kv_heads, new_size, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, new_size, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
self._idx = prev
|
||||
|
||||
# Trim if needed
|
||||
trim_size = self.keys.shape[2] - self.max_size
|
||||
if trim_size > 0:
|
||||
self.keys = self._trim(trim_size, self.keys)
|
||||
self.values = self._trim(trim_size, self.values)
|
||||
self._idx = self.max_size
|
||||
|
||||
# Rotate
|
||||
if self._idx == self.max_size:
|
||||
self._idx = self.keep
|
||||
|
||||
# Assign
|
||||
self.keys[..., self._idx : self._idx + S, :] = keys
|
||||
self.values[..., self._idx : self._idx + S, :] = values
|
||||
self.offset += S
|
||||
self._idx += S
|
||||
|
||||
# If the buffer is not full, slice off the end
|
||||
if self.offset < self.max_size:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
if keys.shape[2] == 1:
|
||||
return self._update_in_place(keys, values)
|
||||
return self._update_concat(keys, values)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset < self.keys.shape[2]:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
else:
|
||||
return self.keys, self.values
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(
|
||||
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
|
||||
)
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.keep, self.max_size, self.step, self.offset, self._idx = map(
|
||||
int,
|
||||
v,
|
||||
)
|
||||
|
||||
def is_trimmable(self):
|
||||
return self.offset < self.max_size
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
self._idx -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
raise NotImplementedError("RotatingKVCache Quantization NYI")
|
||||
|
||||
|
||||
class MambaCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
def __setitem__(self, idx, value):
|
||||
self.cache[idx] = value
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cache[idx]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v
|
||||
@@ -1,195 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 8192
|
||||
num_hidden_layers: int = 40
|
||||
intermediate_size: int = 22528
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 64
|
||||
rope_theta: float = 8000000.0
|
||||
vocab_size: int = 256000
|
||||
layer_norm_eps: float = 1e-05
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
use_qk_norm: bool = False
|
||||
|
||||
|
||||
class LayerNorm2D(nn.Module):
|
||||
|
||||
def __init__(self, d1, d2, eps):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((d1, d2))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
attetion_bias = args.attention_bias
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
|
||||
|
||||
self.use_qk_norm = args.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
|
||||
self.k_norm = LayerNorm2D(
|
||||
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.n_heads, -1)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1)
|
||||
if self.use_qk_norm:
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = CohereModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = out * self.model.args.logit_scale
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,206 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .cache import KVCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 4096
|
||||
head_dim: int = 128
|
||||
num_hidden_layers: int = 32
|
||||
intermediate_size: int = 14336
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
rope_theta: float = 50000.0
|
||||
vocab_size: int = 256000
|
||||
layer_norm_eps: float = 1e-05
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
sliding_window: int = 4096
|
||||
sliding_window_pattern: int = 4
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
if (head_dim * n_heads) != dim:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
|
||||
f" and `num_heads`: {n_heads})."
|
||||
)
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
attetion_bias = args.attention_bias
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Apply RoPE only if sliding window is enabled
|
||||
if self.use_sliding_window:
|
||||
if cache is None:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
else:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
if self.use_sliding_window and mask is not None:
|
||||
key_len = keys.shape[-2]
|
||||
if mask.shape[-1] != key_len:
|
||||
mask = mask[..., -key_len:]
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args, layer_idx)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args, layer_idx=i)
|
||||
for i in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
if mask is None:
|
||||
j = self.args.sliding_window_pattern
|
||||
mask = create_attention_mask(h, cache[j - 1 : j])
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = CohereModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = out * self.model.args.logit_scale
|
||||
return out
|
||||
|
||||
def make_cache(self):
|
||||
caches = []
|
||||
for i in range(self.args.num_hidden_layers):
|
||||
if (
|
||||
i % self.args.sliding_window_pattern
|
||||
== self.args.sliding_window_pattern - 1
|
||||
):
|
||||
caches.append(KVCache())
|
||||
else:
|
||||
caches.append(
|
||||
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
|
||||
)
|
||||
return caches
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,254 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
d_model: int
|
||||
ffn_config: dict
|
||||
attn_config: dict
|
||||
n_layers: int
|
||||
n_heads: int
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_heads = args.n_heads
|
||||
self.d_model = args.d_model
|
||||
self.head_dim = args.d_model // args.n_heads
|
||||
self.num_key_value_heads = args.attn_config["kv_n_heads"]
|
||||
self.clip_qkv = args.attn_config["clip_qkv"]
|
||||
self.rope_theta = args.attn_config["rope_theta"]
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.Wqkv = nn.Linear(
|
||||
args.d_model,
|
||||
(self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
|
||||
splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
|
||||
queries, keys, values = mx.split(qkv, splits, axis=-1)
|
||||
|
||||
B, L, D = x.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class NormAttnNorm(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
|
||||
self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
|
||||
self.attn = Attention(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
|
||||
x = h + x
|
||||
return x, self.norm_2(x)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, d_model: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||
self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||
self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
|
||||
current_hidden_states = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class Router(nn.Module):
|
||||
def __init__(self, d_model: int, num_experts: int):
|
||||
super().__init__()
|
||||
self.layer = nn.Linear(d_model, num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class SparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.d_model = args.d_model
|
||||
self.ffn_dim = args.ffn_config["ffn_hidden_size"]
|
||||
self.num_experts = args.ffn_config["moe_num_experts"]
|
||||
self.num_experts_per_tok = args.ffn_config["moe_top_k"]
|
||||
|
||||
self.router = Router(self.d_model, self.num_experts)
|
||||
self.experts = [
|
||||
MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.router(x)
|
||||
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
|
||||
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
|
||||
scores = scores.astype(x.dtype)
|
||||
|
||||
if self.training:
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt)
|
||||
y = mx.stack(y, axis=0)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ffn = SparseMoeBlock(args)
|
||||
self.norm_attn_norm = NormAttnNorm(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r, h = self.norm_attn_norm(x, mask, cache)
|
||||
out = self.ffn(h) + r
|
||||
return out
|
||||
|
||||
|
||||
class DBRX(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = args.vocab_size
|
||||
self.wte = nn.Embedding(args.vocab_size, args.d_model)
|
||||
self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
|
||||
self.norm_f = nn.LayerNorm(args.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.blocks)
|
||||
|
||||
for layer, c in zip(self.blocks, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm_f(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.transformer = DBRX(args)
|
||||
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.blocks
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Split experts into sub matrices
|
||||
num_experts = self.args.ffn_config["moe_num_experts"]
|
||||
dim = self.args.ffn_config["ffn_hidden_size"]
|
||||
|
||||
pattern = "experts.mlp"
|
||||
new_weights = {k: v for k, v in weights.items() if pattern not in k}
|
||||
for k, v in weights.items():
|
||||
if pattern in k:
|
||||
experts = [
|
||||
(k.replace(".mlp", f".{e}") + ".weight", sv)
|
||||
for e, sv in enumerate(mx.split(v, num_experts, axis=0))
|
||||
]
|
||||
if k.endswith("w2"):
|
||||
experts = [(s, sv.T) for s, sv in experts]
|
||||
new_weights.update(experts)
|
||||
return new_weights
|
||||
@@ -1,261 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Optional[Dict] = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
class DeepseekAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
attention_bias = getattr(config, "attention_bias", False)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_attention_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_key_value_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_key_value_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_attention_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
rope_scale = 1.0
|
||||
if config.rope_scaling and config.rope_scaling["type"] == "linear":
|
||||
assert isinstance(config.rope_scaling["factor"], float)
|
||||
rope_scale = 1 / config.rope_scaling["factor"]
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
base=config.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = hidden_size or config.hidden_size
|
||||
self.intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
scores = mx.softmax(gates, axis=-1, precise=True)
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekMoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekMLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekAttention(config)
|
||||
self.mlp = (
|
||||
DeepseekMoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekMLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for m in ["gate_proj", "down_proj", "up_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,460 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v2"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
topk_method: str = "gready"
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV2YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
if self.topk_method == "group_limited_greedy":
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = scores.max(axis=-1)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV2Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV2MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV2MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV2DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV2Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
@@ -1,478 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v3"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
topk_method: str = "noaux_tc"
|
||||
scoring_func: str = "sigmoid"
|
||||
norm_topk_prob: bool = True
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
# A clipped silu to prevent fp16 from overflowing
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def clipped_silu(x):
|
||||
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV3YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||
|
||||
assert self.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores + self.e_score_correction_bias
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||
scores = scores / denominator
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size,
|
||||
config.moe_intermediate_size,
|
||||
config.n_routed_experts,
|
||||
activation=clipped_silu,
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV3DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV3Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
# Remove multi-token prediction layer
|
||||
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
@@ -1,166 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .rope_utils import initialize_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
vocab_size: int
|
||||
rope_theta: float
|
||||
layer_norm_epsilon: float
|
||||
num_key_value_heads: int
|
||||
head_dim: Optional[int] = None
|
||||
max_position_embeddings: Optional[int] = None
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim or (dim // n_heads)
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
||||
|
||||
self.rope = initialize_rope(
|
||||
self.head_dim,
|
||||
args.rope_theta,
|
||||
args.rope_traditional,
|
||||
args.rope_scaling,
|
||||
args.max_position_embeddings,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
q = self.rope(q, offset=cache.offset)
|
||||
k = self.rope(k, offset=cache.offset)
|
||||
k, v = cache.update_and_fetch(k, v)
|
||||
else:
|
||||
q = self.rope(q)
|
||||
k = self.rope(k)
|
||||
|
||||
out = scaled_dot_product_attention(
|
||||
q, k, v, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.attention = AttentionModule(args)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
|
||||
self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = x + self.attn.attention(self.ln_1(x), mask, cache)
|
||||
out = h + self.mlp(self.ln_2(h))
|
||||
return out
|
||||
|
||||
|
||||
class ExaoneModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
|
||||
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.ln_f(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = ExaoneModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,178 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
h = h * (self.args.hidden_size**0.5)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = GemmaModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
attn_logit_softcapping: float = 50.0
|
||||
final_logit_softcapping: float = 30.0
|
||||
query_pre_attn_scalar: float = 144.0
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
|
||||
self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.attn_logit_softcapping = args.attn_logit_softcapping
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries * self.scale
|
||||
|
||||
if self.repeats > 1:
|
||||
queries = queries.reshape(
|
||||
B, self.n_kv_heads, self.repeats, L, self.head_dim
|
||||
)
|
||||
keys = mx.expand_dims(keys, 2)
|
||||
values = mx.expand_dims(values, 2)
|
||||
|
||||
scores = queries @ keys.swapaxes(-1, -2)
|
||||
scores = mx.tanh(scores / self.attn_logit_softcapping)
|
||||
scores *= self.attn_logit_softcapping
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, precise=True, axis=-1)
|
||||
output = scores @ values
|
||||
if self.repeats > 1:
|
||||
output = output.reshape(B, self.n_heads, L, self.head_dim)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + self.post_attention_layernorm(r)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h))
|
||||
out = h + self.post_feedforward_layernorm(r)
|
||||
return out
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
h = h * (self.args.hidden_size**0.5)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.final_logit_softcapping = args.final_logit_softcapping
|
||||
self.model = GemmaModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = mx.tanh(out / self.final_logit_softcapping)
|
||||
out = out * self.final_logit_softcapping
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,201 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_ctx: int
|
||||
n_embd: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
|
||||
|
||||
self.n_embd = args.n_embd
|
||||
self.n_head = args.n_head
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.n_embd = args.n_embd
|
||||
self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
|
||||
self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu_approx(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.layer_norm_epsilon = args.layer_norm_epsilon
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(
|
||||
self.n_embd,
|
||||
eps=self.layer_norm_epsilon,
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_embd = args.n_embd
|
||||
self.n_positions = args.n_positions
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layer = args.n_layer
|
||||
self.layer_norm_epsilon = args.layer_norm_epsilon
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(self.vocab_size, self.n_embd)
|
||||
self.wpe = nn.Embedding(self.n_positions, self.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
_, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if hidden_states.shape[1] > 1:
|
||||
|
||||
position_ids = mx.array(np.arange(L))
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GPT2Model(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.wte.as_linear(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for i in range(self.args.n_layer):
|
||||
if f"h.{i}.attn.bias" in weights:
|
||||
del weights[f"h.{i}.attn.bias"]
|
||||
if f"h.{i}.attn.c_attn.weight" in weights:
|
||||
weights[f"h.{i}.attn.c_attn.weight"] = weights[
|
||||
f"h.{i}.attn.c_attn.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.attn.c_proj.weight" in weights:
|
||||
weights[f"h.{i}.attn.c_proj.weight"] = weights[
|
||||
f"h.{i}.attn.c_proj.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.mlp.c_fc.weight" in weights:
|
||||
weights[f"h.{i}.mlp.c_fc.weight"] = weights[
|
||||
f"h.{i}.mlp.c_fc.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.mlp.c_proj.weight" in weights:
|
||||
weights[f"h.{i}.mlp.c_proj.weight"] = weights[
|
||||
f"h.{i}.mlp.c_proj.weight"
|
||||
].transpose(1, 0)
|
||||
for weight in weights:
|
||||
if not weight.startswith("model."):
|
||||
new_weights[f"model.{weight}"] = weights[weight]
|
||||
else:
|
||||
new_weights[weight] = weights[weight]
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
@@ -1,189 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_embd: int
|
||||
n_layer: int
|
||||
n_inner: int
|
||||
n_head: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
multi_query: bool = True
|
||||
attention_bias: bool = True
|
||||
mlp_bias: bool = True
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = 1 if self.multi_query else self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim = args.n_embd
|
||||
self.n_heads = n_heads = args.n_head
|
||||
self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
|
||||
|
||||
self.head_dim = head_dim = dim // n_heads
|
||||
|
||||
self.kv_dim = n_kv_heads * head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
|
||||
self.c_proj = nn.Linear(dim, dim, bias=attention_bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [self.dim, self.dim + self.kv_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.n_embd
|
||||
hidden_dim = args.n_inner
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
self.wpe = nn.Embedding(args.n_positions, args.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
B, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if mask is not None and hidden_states.shape[1] > 1:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
position_ids = mx.array(np.arange(L))
|
||||
else:
|
||||
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
|
||||
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = GPTBigCodeModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
# Based on the transformers implementation at:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
max_position_embeddings: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
layer_norm_eps: float
|
||||
vocab_size: int
|
||||
rotary_emb_base: int
|
||||
rotary_pct: float
|
||||
num_key_value_heads: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
args.hidden_size % args.num_attention_heads == 0
|
||||
), "hidden_size must be divisible by num_attention_heads"
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
dims=int(self.head_dim * args.rotary_pct),
|
||||
traditional=False,
|
||||
base=args.rotary_emb_base,
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = nn.Linear(
|
||||
self.hidden_size, 3 * self.hidden_size, bias=True
|
||||
)
|
||||
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.query_key_value(x)
|
||||
|
||||
new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
|
||||
qkv = qkv.reshape(*new_qkv_shape)
|
||||
|
||||
queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
|
||||
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
# gelu_approx corresponds to FastGELUActivation in transformers.
|
||||
return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.layer_norm_eps = args.layer_norm_eps
|
||||
self.attention = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
self.hidden_size,
|
||||
eps=self.layer_norm_eps,
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
self.hidden_size, eps=self.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
# NeoX runs attention and feedforward network in parallel.
|
||||
attn = self.attention(self.input_layernorm(x), mask, cache)
|
||||
ffn = self.mlp(self.post_attention_layernorm(x))
|
||||
out = attn + ffn + residual
|
||||
return out
|
||||
|
||||
|
||||
class GPTNeoXModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
self.layer_norm_eps = args.layer_norm_eps
|
||||
assert self.vocab_size > 0
|
||||
self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
|
||||
self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
|
||||
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
_, L = inputs.shape
|
||||
|
||||
hidden_states = self.embed_in(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
out = self.final_layer_norm(hidden_states)
|
||||
out = self.embed_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GPTNeoXModel(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
|
||||
for w_key, w_value in weights.items():
|
||||
# Created through register_buffer in Pytorch, not needed here.
|
||||
ignore_suffixes = [
|
||||
".attention.bias",
|
||||
".attention.masked_bias",
|
||||
".attention.rotary_emb.inv_freq",
|
||||
]
|
||||
|
||||
skip_weight = False
|
||||
for ignored_suffix in ignore_suffixes:
|
||||
if w_key.endswith(ignored_suffix):
|
||||
skip_weight = True
|
||||
break
|
||||
|
||||
if skip_weight:
|
||||
continue
|
||||
|
||||
if not w_key.startswith("model."):
|
||||
w_key = f"model.{w_key}"
|
||||
|
||||
w_key = w_key.replace(".gpt_neox.layers.", ".h.")
|
||||
w_key = w_key.replace(".gpt_neox.", ".")
|
||||
|
||||
new_weights[w_key] = w_value
|
||||
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
@@ -1,185 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
attention_bias: bool
|
||||
head_dim: int
|
||||
max_position_embeddings: int
|
||||
mlp_bias: bool
|
||||
model_type: str
|
||||
rope_theta: float
|
||||
tie_word_embeddings: bool
|
||||
|
||||
|
||||
class HeliumAttention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class HeliumMLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.intermediate_size = args.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||
)
|
||||
self.up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||
)
|
||||
self.down_proj = nn.Linear(
|
||||
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class HeliumDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = HeliumAttention(args)
|
||||
self.mlp = HeliumMLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class HeliumModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
self.vocab_size = args.vocab_size
|
||||
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
|
||||
self.model = HeliumModel(args)
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,294 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
attention_bias: bool
|
||||
moe_topk: int
|
||||
num_experts: int
|
||||
num_shared_expert: int
|
||||
use_mixed_mlp_moe: bool
|
||||
use_qk_norm: bool
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
use_cla: bool
|
||||
cla_share_factor: 2
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
|
||||
class DynamicNTKAlphaRoPE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
base: float = 10000,
|
||||
scaling_alpha: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
base = base * scaling_alpha ** (dims / (dims - 2))
|
||||
self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=False,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, kv_proj: bool, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
if kv_proj:
|
||||
self.k_proj = nn.Linear(
|
||||
dim, n_kv_heads * head_dim, bias=args.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
dim, n_kv_heads * head_dim, bias=args.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
||||
self.use_qk_norm = args.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
|
||||
self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
|
||||
|
||||
self.rope = DynamicNTKAlphaRoPE(
|
||||
head_dim,
|
||||
base=args.rope_theta,
|
||||
scaling_alpha=args.rope_scaling["alpha"],
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
kv_states=None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries = self.q_proj(x)
|
||||
|
||||
if kv_states is None:
|
||||
keys, values = self.k_proj(x), self.v_proj(x)
|
||||
kv_states = keys, values
|
||||
else:
|
||||
keys, values = kv_states
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
offset = cache.offset if cache else 0
|
||||
queries = self.rope(queries, offset=offset)
|
||||
keys = self.rope(keys, offset=offset)
|
||||
if self.use_qk_norm:
|
||||
queries = self.query_layernorm(queries)
|
||||
keys = self.key_layernorm(keys)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), kv_states
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
def __init__(self, dim, num_experts):
|
||||
super().__init__()
|
||||
self.wg = nn.Linear(dim, num_experts, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.wg(x)
|
||||
|
||||
|
||||
class MoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
intermediate_size = args.intermediate_size
|
||||
self.use_shared_mlp = args.use_mixed_mlp_moe
|
||||
|
||||
if args.use_mixed_mlp_moe:
|
||||
self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert)
|
||||
|
||||
self.num_experts = num_experts = args.num_experts
|
||||
self.top_k = args.moe_topk
|
||||
|
||||
self.gate = Gate(dim, num_experts)
|
||||
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
):
|
||||
gates = self.gate(x)
|
||||
gates = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
if self.use_shared_mlp:
|
||||
shared_expert_output = self.shared_mlp(x)
|
||||
y = y + shared_expert_output
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs, kv_proj: bool):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(kv_proj, args)
|
||||
self.mlp = MoeBlock(args)
|
||||
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
):
|
||||
r, shared_kv_states = self.self_attn(
|
||||
self.input_layernorm(x), mask, cache, shared_kv_states
|
||||
)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, shared_kv_states
|
||||
|
||||
|
||||
class HunYuanModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0)
|
||||
for i in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for i, (layer, c) in enumerate(zip(self.layers, cache)):
|
||||
if i % self.args.cla_share_factor == 0:
|
||||
shared_kv_states = None
|
||||
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = HunYuanModel(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.model.embed_tokens.as_linear(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n in ["up_proj", "down_proj", "gate_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = True
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
|
||||
)
|
||||
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 2.0
|
||||
)
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv_states = self.wqkv(x)
|
||||
qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)
|
||||
|
||||
queries = qkv_states[..., : self.n_kv_groups, :]
|
||||
queries = queries.reshape(B, L, -1, self.head_dim)
|
||||
keys = qkv_states[..., -2, :]
|
||||
values = qkv_states[..., -1, :]
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class InternLM2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = InternLM2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.tok_embeddings.as_linear(out)
|
||||
else:
|
||||
out = self.output(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = False
|
||||
qkv_bias: bool = False
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "rope_type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
qkv_bias = args.qkv_bias
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None
|
||||
and args.rope_scaling["rope_type"] == "linear"
|
||||
else 2.0
|
||||
)
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, bias):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class InternLM2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = InternLM2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .rope_utils import initialize_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
head_dim: Optional[int] = None
|
||||
max_position_embeddings: Optional[int] = None
|
||||
num_key_value_heads: Optional[int] = None
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.rope = initialize_rope(
|
||||
self.head_dim,
|
||||
args.rope_theta,
|
||||
args.rope_traditional,
|
||||
args.rope_scaling,
|
||||
args.max_position_embeddings,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,242 +0,0 @@
|
||||
# Copyright © 2024-2025 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .cache import MambaCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
state_size: int
|
||||
num_hidden_layers: int
|
||||
conv_kernel: int
|
||||
use_bias: bool
|
||||
use_conv_bias: bool
|
||||
time_step_rank: int
|
||||
tie_word_embeddings: bool = True
|
||||
use_bcdt_rms: bool = False
|
||||
mixer_rms_eps: float = 1e-6
|
||||
|
||||
def __post_init__(self):
|
||||
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
||||
self.hidden_size = self.d_model
|
||||
if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
|
||||
self.intermediate_size = self.d_inner
|
||||
if not hasattr(self, "state_size") and hasattr(self, "d_state"):
|
||||
self.state_size = self.d_state
|
||||
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
|
||||
self.num_hidden_layers = self.n_layer
|
||||
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
|
||||
self.num_hidden_layers = self.n_layers
|
||||
if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
|
||||
self.conv_kernel = self.d_conv
|
||||
if not hasattr(self, "use_bias") and hasattr(self, "bias"):
|
||||
self.use_bias = self.bias
|
||||
if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
|
||||
self.use_conv_bias = self.conv_bias
|
||||
|
||||
if self.time_step_rank == "auto":
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||
if self.model_type == "falcon_mamba":
|
||||
self.use_bcdt_rms = True
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = padding
|
||||
self.weight = mx.random.normal((self.channels, kernel_size, 1))
|
||||
self.bias = mx.zeros((channels,)) if bias else None
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
B, L, C = x.shape
|
||||
groups, K, _ = self.weight.shape
|
||||
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=groups)
|
||||
|
||||
if self.bias is not None:
|
||||
y = y + self.bias
|
||||
|
||||
return y, x[:, -K + 1 :, :]
|
||||
|
||||
|
||||
class MambaBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.ssm_state_size = args.state_size
|
||||
self.conv_kernel_size = args.conv_kernel
|
||||
self.intermediate_size = args.intermediate_size
|
||||
self.time_step_rank = int(args.time_step_rank)
|
||||
self.use_conv_bias = args.use_conv_bias
|
||||
self.use_bcdt_rms = args.use_bcdt_rms
|
||||
if self.use_bcdt_rms:
|
||||
self.mixer_norm = lambda x: mx.fast.rms_norm(
|
||||
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
|
||||
)
|
||||
|
||||
self.in_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
||||
)
|
||||
|
||||
self.conv1d = DepthWiseConv1d(
|
||||
channels=self.intermediate_size,
|
||||
kernel_size=self.conv_kernel_size,
|
||||
bias=self.use_conv_bias,
|
||||
padding=self.conv_kernel_size - 1,
|
||||
)
|
||||
|
||||
self.x_proj = nn.Linear(
|
||||
self.intermediate_size,
|
||||
self.time_step_rank + 2 * self.ssm_state_size,
|
||||
bias=False,
|
||||
)
|
||||
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
|
||||
|
||||
A = mx.repeat(
|
||||
mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
|
||||
repeats=self.intermediate_size,
|
||||
axis=0,
|
||||
)
|
||||
self.A_log = mx.log(A)
|
||||
self.D = mx.ones([self.intermediate_size])
|
||||
|
||||
self.out_proj = nn.Linear(
|
||||
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||
)
|
||||
|
||||
def ssm_step(self, x, A, state=None):
|
||||
D = self.D
|
||||
deltaBC = self.x_proj(x)
|
||||
delta, B, C = map(
|
||||
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
||||
mx.split(
|
||||
deltaBC,
|
||||
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
|
||||
axis=-1,
|
||||
),
|
||||
)
|
||||
if self.use_bcdt_rms:
|
||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||
delta = nn.softplus(self.dt_proj(delta))
|
||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||
if state is not None:
|
||||
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
||||
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
|
||||
y = y + D * x
|
||||
return y, new_state
|
||||
|
||||
def _process_sequence(self, x, conv_cache, state_cache):
|
||||
B, T, D = x.shape
|
||||
xz = self.in_proj(x)
|
||||
x, z = xz.split(indices_or_sections=2, axis=-1)
|
||||
|
||||
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
|
||||
x = nn.silu(conv_out)
|
||||
|
||||
A = -mx.exp(self.A_log)
|
||||
|
||||
outputs = []
|
||||
current_state = state_cache
|
||||
y = []
|
||||
for t in range(T):
|
||||
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
|
||||
y.append(y_t)
|
||||
y = mx.stack(y, axis=1)
|
||||
z = self.out_proj(nn.silu(z) * y)
|
||||
return z, (new_conv_cache, current_state)
|
||||
|
||||
def __call__(self, x, cache):
|
||||
if cache is None:
|
||||
conv_cache, state_cache = None, None
|
||||
else:
|
||||
conv_cache, state_cache = cache[0], cache[1]
|
||||
|
||||
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
||||
x, conv_cache, state_cache
|
||||
)
|
||||
|
||||
if isinstance(cache, MambaCache):
|
||||
cache[0] = new_conv_cache
|
||||
cache[1] = new_state_cache
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.mixer = MambaBlock(args)
|
||||
self.norm = nn.RMSNorm(args.hidden_size)
|
||||
|
||||
def __call__(self, x: mx.array, cache):
|
||||
return self.mixer(self.norm(x), cache) + x
|
||||
|
||||
|
||||
class Mamba(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.norm_f = nn.RMSNorm(args.hidden_size)
|
||||
|
||||
def __call__(self, x: mx.array, cache):
|
||||
x = self.embeddings(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, c)
|
||||
return self.norm_f(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.backbone = Mamba(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, inputs: mx.array, cache=None):
|
||||
B, T = inputs.shape
|
||||
|
||||
x = self.backbone(inputs, cache)
|
||||
|
||||
if self.args.tie_word_embeddings:
|
||||
logits = self.backbone.embeddings.as_linear(x)
|
||||
else:
|
||||
logits = self.lm_head(x)
|
||||
|
||||
return logits
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
def make_cache(self):
|
||||
return [MambaCache() for _ in range(len(self.layers))]
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.backbone.layers
|
||||
@@ -1,210 +0,0 @@
|
||||
# Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
dim_model_base: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
scale_depth: float
|
||||
scale_emb: float
|
||||
rope_theta: float = 1000000.0
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[str, float]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_heads = n_heads = args.num_attention_heads
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
dims=self.head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=self.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
attn_output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.o_proj(attn_output)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
self.scale_depth = args.scale_depth
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||
return out
|
||||
|
||||
|
||||
class MiniCPMModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
assert self.vocab_size > 0
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs) * self.args.scale_emb
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = MiniCPMModel(args)
|
||||
|
||||
if not self.args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
|
||||
if not self.args.tie_word_embeddings:
|
||||
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
|
||||
else:
|
||||
out = out @ self.model.embed_tokens.weight.T
|
||||
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,220 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int = 32000
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 14336
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_experts_per_tok: int = 2
|
||||
num_key_value_heads: int = 8
|
||||
num_local_experts: int = 8
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 1e6
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MixtralSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_dim = args.hidden_size
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
|
||||
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = MixtralAttention(args)
|
||||
|
||||
self.block_sparse_moe = MixtralSparseMoeBlock(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.block_sparse_moe(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class MixtralModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = MixtralModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(
|
||||
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
|
||||
)
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
|
||||
mx.stack(to_join)
|
||||
)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,220 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
hidden_act: str
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
head_dim: Optional[int] = None
|
||||
max_position_embeddings: Optional[int] = None
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
partial_rotary_factor: float = 0.5
|
||||
rope_theta: float = 10000.0
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.rope_scaling:
|
||||
if not "factor" in self.rope_scaling:
|
||||
raise ValueError(f"rope_scaling must contain 'factor'")
|
||||
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
|
||||
"rope_type"
|
||||
)
|
||||
if rope_type is None:
|
||||
raise ValueError(
|
||||
f"rope_scaling must contain either 'type' or 'rope_type'"
|
||||
)
|
||||
if rope_type not in ["linear"]:
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def relu_squared(x):
|
||||
return nn.relu(x).square()
|
||||
|
||||
|
||||
class NemotronLayerNorm1P(nn.LayerNorm):
|
||||
def __call__(self, x):
|
||||
weight = self.weight + 1 if "weight" in self else None
|
||||
bias = self.bias if "bias" in self else None
|
||||
return mx.fast.layer_norm(x, weight, bias, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
|
||||
self.partial_rotary_factor = args.partial_rotary_factor
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
rope_scale = 1.0
|
||||
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
||||
assert isinstance(args.rope_scaling["factor"], float)
|
||||
rope_scale = 1 / args.rope_scaling["factor"]
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
mlp_bias = args.mlp_bias
|
||||
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(relu_squared(self.up_proj(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
|
||||
self.post_attention_layernorm = NemotronLayerNorm1P(
|
||||
args.hidden_size, eps=args.norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class NemotronModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = NemotronModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,180 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
d_model: int
|
||||
n_layers: int
|
||||
mlp_hidden_size: int
|
||||
n_heads: int
|
||||
vocab_size: int
|
||||
embedding_size: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
mlp_ratio: int = 4
|
||||
weight_tying: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.mlp_hidden_size = (
|
||||
self.mlp_hidden_size
|
||||
if self.mlp_hidden_size is not None
|
||||
else self.mlp_ratio * self.d_model
|
||||
)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
dim = args.d_model
|
||||
|
||||
self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
|
||||
self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
|
||||
|
||||
self.att_norm = nn.LayerNorm(dim, affine=False)
|
||||
self.ff_norm = nn.LayerNorm(dim, affine=False)
|
||||
|
||||
head_dim = dim // self.n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
|
||||
self.attn_out = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
self.args = args
|
||||
|
||||
def attend(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.attn_out(output)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attend(self.att_norm(x), mask, cache)
|
||||
h = x + r
|
||||
|
||||
x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
|
||||
|
||||
out = h + self.ff_out(nn.silu(x2) * x1)
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_layers = args.n_layers
|
||||
self.weight_tying = args.weight_tying
|
||||
|
||||
self.wte = nn.Embedding(args.embedding_size, args.d_model)
|
||||
self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
if not self.weight_tying:
|
||||
self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
|
||||
self.norm = nn.LayerNorm(args.d_model, affine=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.blocks)
|
||||
|
||||
for block, c in zip(self.blocks, cache):
|
||||
h = block(h, mask, c)
|
||||
|
||||
h = self.norm(h)
|
||||
|
||||
if self.weight_tying:
|
||||
return self.wte.as_linear(h), cache
|
||||
|
||||
return self.ff_out(h)
|
||||
|
||||
|
||||
class OlmoModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.transformer = Transformer(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
return self.transformer(inputs, mask, cache)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = OlmoModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
return self.model(inputs, mask, cache)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.transformer.blocks
|
||||
@@ -1,212 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .rope_utils import initialize_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
head_dim: Optional[int] = None
|
||||
max_position_embeddings: Optional[int] = None
|
||||
num_key_value_heads: Optional[int] = None
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.rope = initialize_rope(
|
||||
self.head_dim,
|
||||
args.rope_theta,
|
||||
args.rope_traditional,
|
||||
args.rope_scaling,
|
||||
args.max_position_embeddings,
|
||||
)
|
||||
|
||||
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
|
||||
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_feedforward_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
|
||||
h = x + r
|
||||
r = self.post_feedforward_layernorm(self.mlp(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,223 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
head_dim: int
|
||||
num_transformer_layers: int
|
||||
model_dim: int
|
||||
vocab_size: int
|
||||
ffn_dim_divisor: int
|
||||
num_query_heads: List
|
||||
num_kv_heads: List
|
||||
ffn_multipliers: List
|
||||
ffn_with_glu: bool = True
|
||||
normalize_qk_projections: bool = True
|
||||
share_input_output_layers: bool = True
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_freq_constant: float = 10000
|
||||
|
||||
|
||||
def make_divisible(
|
||||
v: Union[float, int],
|
||||
divisor: Optional[int] = 8,
|
||||
min_value: Optional[Union[float, int]] = None,
|
||||
) -> Union[float, int]:
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by the divisor
|
||||
It can be seen at:
|
||||
https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
|
||||
Args:
|
||||
v: input value
|
||||
divisor: default to 8
|
||||
min_value: minimum divisor value
|
||||
Returns:
|
||||
new_v: new divisible value
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_id: int):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
self.layer_id = layer_id
|
||||
self.model_dim = model_dim = args.model_dim
|
||||
|
||||
self.n_heads = n_heads = args.num_query_heads[layer_id]
|
||||
self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id]
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
op_size = (n_heads + (n_kv_heads * 2)) * head_dim
|
||||
self.qkv_proj = nn.Linear(model_dim, op_size, bias=False)
|
||||
self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)
|
||||
|
||||
self.normalize_qk_projections = args.normalize_qk_projections
|
||||
|
||||
if self.normalize_qk_projections:
|
||||
self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
|
||||
self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.qkv_proj(x)
|
||||
|
||||
qkv = qkv.reshape(
|
||||
B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
if self.normalize_qk_projections:
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_id: int):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
dim = args.model_dim
|
||||
ffn_multiplier = args.ffn_multipliers[layer_id]
|
||||
|
||||
intermediate_dim = int(
|
||||
make_divisible(
|
||||
ffn_multiplier * args.model_dim,
|
||||
divisor=args.ffn_dim_divisor,
|
||||
)
|
||||
)
|
||||
|
||||
self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False)
|
||||
self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.proj_1(x)
|
||||
gate, x = mx.split(x, 2, axis=-1)
|
||||
return self.proj_2(nn.silu(gate) * x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_id: int):
|
||||
super().__init__()
|
||||
dim = args.model_dim
|
||||
self.attn = Attention(args, layer_id=layer_id)
|
||||
self.ffn = MLP(args, layer_id=layer_id)
|
||||
self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
|
||||
self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.attn_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.ffn(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class OpenELMModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_transformer_layers = args.num_transformer_layers
|
||||
assert self.vocab_size > 0
|
||||
self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
|
||||
self.layers = [
|
||||
TransformerBlock(args, layer_id=layer_id)
|
||||
for layer_id in range(self.num_transformer_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.token_embeddings(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = OpenELMModel(args)
|
||||
if not args.share_input_output_layers:
|
||||
self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.share_input_output_layers:
|
||||
out = self.transformer.token_embeddings.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.layers
|
||||
@@ -1,179 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "phi"
|
||||
max_position_embeddings: int = 2048
|
||||
vocab_size: int = 51200
|
||||
hidden_size: int = 2560
|
||||
num_attention_heads: int = 32
|
||||
num_hidden_layers: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
partial_rotary_factor: float = 0.4
|
||||
intermediate_size: int = 10240
|
||||
layer_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class PhiAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.repeats = self.num_heads // self.num_key_value_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.dense = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=True
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
B, L, D = queries.shape
|
||||
n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(
|
||||
B,
|
||||
L,
|
||||
n_heads,
|
||||
-1,
|
||||
).moveaxis(1, 2)
|
||||
keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
|
||||
values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = scaled_dot_product_attention(
|
||||
queries.astype(mx.float32),
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
).astype(values.dtype)
|
||||
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class PhiDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = PhiAttention(config=config)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = PhiMLP(config)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class PhiModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, c)
|
||||
return self.final_layernorm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = PhiModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,207 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
|
||||
max_position_embeddings: int = 131072
|
||||
original_max_position_embeddings: int = 4096
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"long_factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
|
||||
print(
|
||||
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
|
||||
)
|
||||
self.rope_scaling = None
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
|
||||
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
||||
self.rope = SuScaledRotaryEmbedding(
|
||||
head_dim,
|
||||
base=args.rope_theta,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||
short_factor=args.rope_scaling["short_factor"],
|
||||
long_factor=args.rope_scaling["long_factor"],
|
||||
)
|
||||
else:
|
||||
rope_scale = 1.0
|
||||
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
||||
assert isinstance(args.rope_scaling["factor"], float)
|
||||
rope_scale = 1 / args.rope_scaling["factor"]
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.qkv_proj(x)
|
||||
query_pos = self.n_heads * self.head_dim
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.gate_up_proj(x)
|
||||
gate, x = mx.split(x, 2, axis=-1)
|
||||
return self.down_proj(nn.silu(gate) * x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,313 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
dense_attention_every_n_layers: int
|
||||
ff_intermediate_size: int
|
||||
gegelu_limit: float
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
mup_attn_multiplier: float = 1.0
|
||||
mup_use_scaling: bool = True
|
||||
mup_embedding_multiplier: float = 10.0
|
||||
mup_width_multiplier: float = 8.0
|
||||
rope_embedding_base: float = 1000000
|
||||
rope_position_scale: float = 1.0
|
||||
blocksparse_block_size: int = 64
|
||||
blocksparse_num_local_blocks: int = 16
|
||||
blocksparse_vert_stride: int = 8
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def gegelu_impl(a_gelu, a_linear, limit):
|
||||
a_gelu = mx.where(
|
||||
mx.isinf(a_gelu),
|
||||
a_gelu,
|
||||
mx.clip(a_gelu, a_min=None, a_max=limit),
|
||||
)
|
||||
a_linear = mx.where(
|
||||
mx.isinf(a_linear),
|
||||
a_linear,
|
||||
mx.clip(a_linear, a_min=-limit, a_max=limit),
|
||||
)
|
||||
out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
|
||||
return out_gelu * (a_linear + 1.0)
|
||||
|
||||
|
||||
def gegelu(x, limit):
|
||||
a_gelu, a_linear = x[..., ::2], x[..., 1::2]
|
||||
return gegelu_impl(a_gelu, a_linear, limit)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_q_per_kv = n_heads // n_kv_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
|
||||
self.query_key_value = nn.Linear(
|
||||
dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
|
||||
)
|
||||
self.dense = nn.Linear(dim, dim)
|
||||
|
||||
if args.mup_use_scaling:
|
||||
norm_factor = head_dim / args.mup_attn_multiplier
|
||||
else:
|
||||
norm_factor = math.sqrt(head_dim)
|
||||
self.scale = 1.0 / norm_factor
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=False,
|
||||
base=args.rope_embedding_base,
|
||||
scale=args.rope_position_scale,
|
||||
)
|
||||
|
||||
if layer_idx % args.dense_attention_every_n_layers == 0:
|
||||
self.block_sparse = True
|
||||
self.blocksparse_block_size = args.blocksparse_block_size
|
||||
if self.blocksparse_block_size not in (32, 64):
|
||||
raise ValueError(
|
||||
f"Unsupported block size {self.blocksparse_block_size}"
|
||||
)
|
||||
self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
|
||||
self.blocksparse_vert_stride = args.blocksparse_vert_stride
|
||||
else:
|
||||
self.block_sparse = False
|
||||
|
||||
def _block_sparse_mask(self, q_len, kv_len):
|
||||
vert_stride = self.blocksparse_vert_stride
|
||||
local_blocks = self.blocksparse_num_local_blocks
|
||||
block_size = self.blocksparse_block_size
|
||||
n_heads = self.n_heads
|
||||
|
||||
kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
q_blocks = (q_len + block_size - 1) // block_size
|
||||
q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
|
||||
k_pos = mx.arange(kv_blocks)[None, None]
|
||||
|
||||
mask_vert_strided = (
|
||||
mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
|
||||
) % vert_stride
|
||||
mask_vert_strided = (mask_vert_strided == 0)[:, None, :]
|
||||
|
||||
block_mask = (q_pos >= k_pos) & (
|
||||
(q_pos - k_pos < local_blocks) | mask_vert_strided
|
||||
)
|
||||
block_mask = block_mask.reshape(
|
||||
self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
|
||||
)
|
||||
dense_mask = mx.repeat(
|
||||
mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
|
||||
)
|
||||
return block_mask, dense_mask[..., -q_len:, :kv_len]
|
||||
|
||||
def _block_sparse_attention(self, queries, keys, values, scale, mask):
|
||||
queries = scale * queries
|
||||
B = queries.shape[0]
|
||||
L = queries.shape[2]
|
||||
queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
|
||||
keys = mx.expand_dims(keys, 2)
|
||||
values = mx.expand_dims(values, 2)
|
||||
|
||||
# TODO get rid of dense mask if we have a fill value
|
||||
block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
|
||||
scores = queries @ mx.swapaxes(keys, -1, -2)
|
||||
# TODO, uncomment when faster
|
||||
# scores = mx.block_masked_mm(
|
||||
# queries,
|
||||
# mx.swapaxes(keys, -1, -2),
|
||||
# mask_out=block_mask,
|
||||
# block_size=self.blocksparse_block_size,
|
||||
# )
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = scores + mx.where(
|
||||
dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
|
||||
)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
output = scores @ values
|
||||
# TODO, uncomment when faster
|
||||
# output = mx.block_masked_mm(
|
||||
# scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
|
||||
# )
|
||||
return mx.reshape(output, (B, self.n_heads, L, -1))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.query_key_value(x)
|
||||
qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
|
||||
queries = qkv[..., :-2, :].flatten(-3, -2)
|
||||
keys = qkv[..., -2, :]
|
||||
values = qkv[..., -1, :]
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
if self.block_sparse:
|
||||
output = self._block_sparse_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.ff_intermediate_size
|
||||
self.gegelu_limit = args.gegelu_limit
|
||||
self.up_proj = nn.Linear(dim, 2 * hidden_dim)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.up_proj(x)
|
||||
return self.down_proj(gegelu(x, self.gegelu_limit))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args, layer_idx)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_epsilon
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layer_norm_epsilon,
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.mup_embedding_multiplier = args.mup_embedding_multiplier
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args, layer_idx=l)
|
||||
for l in range(args.num_hidden_layers)
|
||||
]
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_epsilon
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
if self.mup_embedding_multiplier:
|
||||
h = self.mup_embedding_multiplier * h
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.final_layernorm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
self.args = args
|
||||
self.mup_width_multiplier = args.mup_width_multiplier
|
||||
self._dummy_tokenizer_ids = mx.array(
|
||||
[100256, 100258, 100259, 100260, 100264, 100265]
|
||||
+ list(range(100267, 100352))
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
if self.mup_width_multiplier:
|
||||
out = out / self.mup_width_multiplier
|
||||
out[self._dummy_tokenizer_ids] = -float("inf")
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "phimoe"
|
||||
vocab_size: int = 32064
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 6400
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 131072
|
||||
original_max_position_embeddings: int = 4096
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_scaling: Dict[str, Union[float, List[float]]] = None
|
||||
num_local_experts: int = 16
|
||||
num_experts_per_tok: int = 2
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
|
||||
|
||||
self.rope = SuScaledRotaryEmbedding(
|
||||
head_dim,
|
||||
base=args.rope_theta,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||
short_factor=args.rope_scaling["short_factor"],
|
||||
long_factor=args.rope_scaling["long_factor"],
|
||||
short_mscale=args.rope_scaling["short_mscale"],
|
||||
long_mscale=args.rope_scaling["long_mscale"],
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class PhiMoESparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_dim = args.hidden_size
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.num_experts = args.num_local_experts
|
||||
self.top_k = args.num_experts_per_tok
|
||||
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class PhiMoEDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.block_sparse_moe = PhiMoESparseMoeBlock(args)
|
||||
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
hidden_states = self.input_layernorm(x)
|
||||
hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PhiMoEModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.args = args
|
||||
self.model = PhiMoEModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(
|
||||
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
|
||||
)
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
|
||||
mx.stack(to_join)
|
||||
)
|
||||
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,202 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchMLP
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
model_type: str
|
||||
num_vocab: int = 51200
|
||||
model_dim: int = 2560
|
||||
num_heads: int = 32
|
||||
num_layers: int = 32
|
||||
rotary_dim: int = 32
|
||||
num_experts_per_tok: int = 2
|
||||
num_local_experts: int = 4
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
self.Wqkv = nn.Linear(dims, 3 * dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.Wqkv(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries.astype(mx.float32),
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
).astype(values.dtype)
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class MOE(nn.Module):
|
||||
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
self.switch_mlp = SwitchMLP(
|
||||
self.dim, self.hidden_dim, self.num_experts, bias=True
|
||||
)
|
||||
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.model_dim
|
||||
mlp_dims = dims * 4
|
||||
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||
self.ln = nn.LayerNorm(dims)
|
||||
self.moe = MOE(config, dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.ln(x)
|
||||
attn_h = self.mixer(h, mask, cache)
|
||||
ff_h = self.moe(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embd = Embd(config)
|
||||
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embd(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
x = layer(x, mask, c)
|
||||
return x
|
||||
|
||||
|
||||
class Embd(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.wte(x)
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(config.model_dim)
|
||||
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(self.ln(inputs))
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_layers):
|
||||
prefix = f"transformer.h.{l}"
|
||||
for n in ["fc1", "fc2"]:
|
||||
for k in ["weight", "scales", "biases", "bias"]:
|
||||
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,214 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
n_shared_head: int = 8
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
head_dim = self.hidden_size // config.num_attention_heads
|
||||
|
||||
self.q_num_heads = config.num_attention_heads
|
||||
self.qk_dim = self.v_dim = head_dim
|
||||
self.k_num_heads = self.v_num_heads = int(
|
||||
np.ceil(self.q_num_heads / config.n_shared_head)
|
||||
)
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.q_num_heads * self.qk_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.k_num_heads * self.qk_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.v_num_heads * self.v_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
|
||||
)
|
||||
self.rotary_emb = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
base=config.rope_theta,
|
||||
scale=1.0,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
bsz, q_len, _ = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rotary_emb(queries, offset=cache.offset)
|
||||
keys = self.rotary_emb(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rotary_emb(queries)
|
||||
keys = self.rotary_emb(keys)
|
||||
|
||||
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
|
||||
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self.scale,
|
||||
mask=attention_mask,
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
|
||||
|
||||
|
||||
class PlamoDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Attention(config)
|
||||
self.mlp = MLP(config)
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
# from LlamaDecoder
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states_sa = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states_sa + hidden_states_mlp
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PlamoDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
class PlamoModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = PlamoDecoder(config) # type: ignore
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None for _ in range(len(self.layers.layers))]
|
||||
|
||||
for layer, c in zip(self.layers.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = PlamoModel(args)
|
||||
self.lm_head: nn.Module = nn.Linear(
|
||||
args.hidden_size, args.vocab_size, bias=False
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
@@ -1,159 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 2048
|
||||
num_attention_heads: int = 16
|
||||
num_hidden_layers: int = 24
|
||||
kv_channels: int = 128
|
||||
max_position_embeddings: int = 8192
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
intermediate_size: int = 11008
|
||||
no_bias: bool = True
|
||||
vocab_size: int = 151936
|
||||
num_key_value_heads = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
|
||||
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
|
||||
|
||||
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
|
||||
|
||||
proj_size = args.kv_channels * self.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
|
||||
|
||||
self.scale = hidden_size_per_attention_head**-0.5
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.c_attn(x)
|
||||
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
B, L, _ = q.shape
|
||||
|
||||
queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rotary_emb(queries, offset=cache.offset)
|
||||
keys = self.rotary_emb(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rotary_emb(queries)
|
||||
keys = self.rotary_emb(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||
)
|
||||
self.c_proj = nn.Linear(
|
||||
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
a1 = self.w1(x)
|
||||
a2 = self.w2(x)
|
||||
return self.c_proj(a1 * nn.silu(a2))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
residual = x
|
||||
x = self.ln_1(x)
|
||||
x = self.attn(x, mask=mask, cache=cache)
|
||||
residual = x + residual
|
||||
x = self.ln_2(residual)
|
||||
x = self.mlp(x)
|
||||
x = x + residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class QwenModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(self, inputs, mask=None, cache=None):
|
||||
x = self.wte(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
x = layer(x, mask, c)
|
||||
|
||||
return self.ln_f(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = QwenModel(config)
|
||||
self.lm_head = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=not config.no_bias
|
||||
)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,201 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 1000000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
if self.args.tie_word_embeddings:
|
||||
weights.pop("lm_head.weight", None)
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_experts_per_tok: int
|
||||
num_experts: int
|
||||
moe_intermediate_size: int
|
||||
shared_expert_intermediate_size: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 1000000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
intermediate_size = args.moe_intermediate_size
|
||||
shared_expert_intermediate_size = args.shared_expert_intermediate_size
|
||||
|
||||
self.num_experts = num_experts = args.num_experts
|
||||
self.top_k = args.num_experts_per_tok
|
||||
|
||||
self.gate = nn.Linear(dim, num_experts, bias=False)
|
||||
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
|
||||
|
||||
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
|
||||
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
):
|
||||
gates = self.gate(x)
|
||||
gates = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
shared_expert_output = self.shared_expert(x)
|
||||
shared_expert_output = (
|
||||
mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
|
||||
)
|
||||
|
||||
return y + shared_expert_output
|
||||
|
||||
|
||||
class Qwen2MoeDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(args)
|
||||
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2MoeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2MoeModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n in ["up_proj", "down_proj", "gate_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,458 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .cache import MambaCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
attention_bias: bool
|
||||
conv1d_width: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
logits_soft_cap: float
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
attention_window_size: int
|
||||
vocab_size: int
|
||||
embeddings_scale_by_sqrt_dim: bool = True
|
||||
block_types: Optional[List[str]] = None
|
||||
_block_types: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# For some reason these have different names in 2B and 9B
|
||||
if self.block_types is None:
|
||||
self.block_types = self._block_types
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
def rnn_scan(x, a, h0):
|
||||
assert x.ndim == 3
|
||||
assert a.shape == x.shape[-a.ndim :]
|
||||
assert a.dtype == x.dtype
|
||||
|
||||
if x.shape[1] == 1:
|
||||
# Using scan in sampling mode.
|
||||
if h0 is None:
|
||||
return x, x[:, 0]
|
||||
|
||||
else:
|
||||
y = a * h0[:, None] + x
|
||||
return y, y[:, -1]
|
||||
|
||||
else:
|
||||
# Using scan in linear mode.
|
||||
if h0 is not None:
|
||||
h_t = h0
|
||||
else:
|
||||
B, _, D = x.shape
|
||||
h_t = mx.zeros((B, D), dtype=x.dtype)
|
||||
|
||||
y = mx.zeros_like(x)
|
||||
for t in range(x.shape[1]):
|
||||
h_t = a[:, t] * h_t + x[:, t]
|
||||
y[:, t] = h_t
|
||||
|
||||
return y, h_t
|
||||
|
||||
|
||||
class Conv1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((channels, kernel_size, 1))
|
||||
self.bias = mx.zeros((channels,))
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
B, L, C = x.shape
|
||||
groups, K, _ = self.weight.shape
|
||||
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=groups)
|
||||
y = y + self.bias
|
||||
|
||||
return y, x[:, -K + 1 :, :]
|
||||
|
||||
|
||||
class RGLRU(nn.Module):
|
||||
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.width // self.num_heads
|
||||
|
||||
self.recurrent_param = mx.zeros((self.width,))
|
||||
|
||||
self.input_gate_weight = mx.zeros(
|
||||
(self.num_heads, self.head_dim, self.head_dim),
|
||||
)
|
||||
self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
|
||||
|
||||
self.recurrent_gate_weight = mx.zeros(
|
||||
(self.num_heads, self.head_dim, self.head_dim),
|
||||
)
|
||||
self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
def apply_block_linear(h, w, b):
|
||||
h = h.reshape((B, L, self.num_heads, self.head_dim))
|
||||
h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
|
||||
return mx.sigmoid(h.flatten(2, 3))
|
||||
|
||||
# Gates for x and a.
|
||||
gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
|
||||
gate_a = apply_block_linear(
|
||||
x, self.recurrent_gate_weight, self.recurrent_gate_bias
|
||||
)
|
||||
|
||||
# Compute the parameter `A` of the recurrence.
|
||||
log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
|
||||
a = mx.exp(log_a)
|
||||
a_square = mx.exp(2 * log_a)
|
||||
|
||||
# Gate the input.
|
||||
gated_x = x * gate_x
|
||||
|
||||
# Apply gamma normalization to the input.
|
||||
multiplier = mx.sqrt(1 - a_square)
|
||||
if cache is None:
|
||||
multiplier[:, 0, :] = 1.0
|
||||
normalized_x = gated_x * multiplier.astype(x.dtype)
|
||||
|
||||
y, last_h = rnn_scan(
|
||||
x=normalized_x,
|
||||
a=a,
|
||||
h0=cache,
|
||||
)
|
||||
|
||||
return y, last_h
|
||||
|
||||
|
||||
class RecurrentBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
lru_width: int = None,
|
||||
conv1d_temporal_width: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.lru_width = lru_width or width
|
||||
self.conv1d_temporal_width = conv1d_temporal_width
|
||||
|
||||
self.linear_y = nn.Linear(width, self.lru_width)
|
||||
self.linear_x = nn.Linear(width, self.lru_width)
|
||||
self.linear_out = nn.Linear(self.lru_width, width)
|
||||
self.conv_1d = Conv1d(
|
||||
channels=self.lru_width,
|
||||
kernel_size=self.conv1d_temporal_width,
|
||||
)
|
||||
self.rg_lru = RGLRU(
|
||||
width=self.lru_width,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
# y branch.
|
||||
y = self.linear_y(x)
|
||||
y = nn.gelu_approx(y)
|
||||
|
||||
# x branch.
|
||||
x = self.linear_x(x)
|
||||
if cache is None:
|
||||
cache = [None, None]
|
||||
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
|
||||
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
|
||||
|
||||
x = x * y
|
||||
x = self.linear_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LocalAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = (width // num_heads) ** (-0.5)
|
||||
|
||||
self.head_dim = self.width // self.num_heads
|
||||
self.q_proj = nn.Linear(self.width, self.width, bias=False)
|
||||
self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.width, self.width, bias=True)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim // 2,
|
||||
traditional=False,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
|
||||
def __init__(self, width: int, expanded_width: int):
|
||||
super().__init__()
|
||||
self.up_proj = nn.Linear(width, expanded_width // 2)
|
||||
self.gate_proj = nn.Linear(width, expanded_width // 2)
|
||||
self.down_proj = nn.Linear(expanded_width // 2, width)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
gate = self.gate_proj(x)
|
||||
x = self.up_proj(x)
|
||||
return self.down_proj(nn.gelu_approx(gate) * x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
mlp_expanded_width: int,
|
||||
num_heads: int,
|
||||
attention_window_size: int,
|
||||
temporal_block_type: str,
|
||||
lru_width: Optional[int] = None,
|
||||
conv1d_temporal_width: int = 4,
|
||||
):
|
||||
"""Initializes the residual block.
|
||||
|
||||
Args:
|
||||
width: The width of the block.
|
||||
mlp_expanded_width: The width of the expansion inside the MLP block.
|
||||
num_heads: The number of heads for the Attention or the RG-LRU.
|
||||
attention_window_size: The window size for the local attention block.
|
||||
temporal_block_type: Either "recurrent" or "attention", specifying the
|
||||
type of recurrent block to use.
|
||||
lru_width: The width of the RG-LRU if different from `width`.
|
||||
conv1d_temporal_width: The width of the temporal convolution.
|
||||
"""
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.mlp_expanded_width = mlp_expanded_width
|
||||
self.num_heads = num_heads
|
||||
self.attention_window_size = attention_window_size
|
||||
self.temporal_block_type = temporal_block_type
|
||||
self.lru_width = lru_width
|
||||
self.conv1d_temporal_width = conv1d_temporal_width
|
||||
|
||||
self.temporal_pre_norm = RMSNorm(width)
|
||||
if self.temporal_block_type == "recurrent":
|
||||
self.temporal_block = RecurrentBlock(
|
||||
width=self.width,
|
||||
num_heads=self.num_heads,
|
||||
lru_width=self.lru_width,
|
||||
conv1d_temporal_width=self.conv1d_temporal_width,
|
||||
)
|
||||
|
||||
else:
|
||||
self.temporal_block = LocalAttentionBlock(
|
||||
width=self.width,
|
||||
num_heads=self.num_heads,
|
||||
window_size=self.attention_window_size,
|
||||
)
|
||||
|
||||
self.channel_pre_norm = RMSNorm(width)
|
||||
self.mlp_block = MLPBlock(
|
||||
width=self.width,
|
||||
expanded_width=self.mlp_expanded_width,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
raw_x = x
|
||||
|
||||
inputs_normalized = self.temporal_pre_norm(raw_x)
|
||||
|
||||
x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
|
||||
residual = x + raw_x
|
||||
|
||||
x = self.channel_pre_norm(residual)
|
||||
x = self.mlp_block(x)
|
||||
|
||||
x = x + residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Griffin(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
|
||||
block_types = config.block_types
|
||||
|
||||
self.layers = [
|
||||
ResidualBlock(
|
||||
width=config.hidden_size,
|
||||
mlp_expanded_width=config.intermediate_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
attention_window_size=config.attention_window_size,
|
||||
temporal_block_type=block_types[i % len(block_types)],
|
||||
lru_width=None,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tokens,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
x = self.embed_tokens(tokens)
|
||||
if self.scale_by_sqrt_dim:
|
||||
x = x * math.sqrt(x.shape[-1])
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
if block.temporal_block_type != "recurrent":
|
||||
mask_cache = [cache[i]]
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, mask_cache)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, mask=mask, cache=cache[i])
|
||||
|
||||
return self.final_norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
self.args = config
|
||||
self.model = Griffin(config)
|
||||
self.model_type = config.model_type
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
tokens: Sequence of input tokens.
|
||||
"""
|
||||
logits = self.model(tokens, mask=mask, cache=cache)
|
||||
if "lm_head" in self:
|
||||
logits = self.lm_head(logits)
|
||||
else:
|
||||
logits = self.model.embed_tokens.as_linear(logits)
|
||||
|
||||
c = self.args.logits_soft_cap
|
||||
if c:
|
||||
logits = mx.tanh(logits / c) * c
|
||||
return logits
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv_1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
if "lm_head.weight" not in weights:
|
||||
self.pop("lm_head")
|
||||
return weights
|
||||
|
||||
def make_cache(self):
|
||||
cache = []
|
||||
for layer in self.layers:
|
||||
if layer.temporal_block_type == "recurrent":
|
||||
cache.append(MambaCache())
|
||||
else:
|
||||
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
|
||||
return cache
|
||||
@@ -1,91 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class Llama3RoPE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scaling_config: dict = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.traditional = traditional
|
||||
|
||||
factor = scaling_config["factor"]
|
||||
low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
|
||||
old_context_len = scaling_config.get(
|
||||
"original_max_position_embeddings",
|
||||
8192,
|
||||
)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
freqs = base ** (mx.arange(0, dims, 2) / dims)
|
||||
wavelens = 2 * mx.pi * freqs
|
||||
|
||||
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
||||
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
||||
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
||||
high_freq_factor - low_freq_factor
|
||||
)
|
||||
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
||||
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"{self.dims}, traditional={self.traditional}, "
|
||||
f"max_position_embeddings={self.max_position_embeddings}"
|
||||
)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
def initialize_rope(
|
||||
dims,
|
||||
base,
|
||||
traditional,
|
||||
scaling_config: Optional[dict] = None,
|
||||
max_position_embeddings: Optional[int] = None,
|
||||
):
|
||||
if scaling_config is not None:
|
||||
rope_type = scaling_config.get("type") or scaling_config.get(
|
||||
"rope_type", "default"
|
||||
)
|
||||
else:
|
||||
rope_type = "default"
|
||||
|
||||
if rope_type in ["default", "linear"]:
|
||||
scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
|
||||
return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
|
||||
|
||||
elif rope_type == "llama3":
|
||||
return Llama3RoPE(
|
||||
dims=dims,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
traditional=traditional,
|
||||
base=base,
|
||||
scaling_config=scaling_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported RoPE type {rope_type}")
|
||||
@@ -1,211 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
intermediate_size: int
|
||||
rope_theta: float
|
||||
use_qkv_bias: bool
|
||||
partial_rotary_factor: float
|
||||
layer_norm_eps: float
|
||||
use_parallel_residual: bool = False
|
||||
qk_layernorm: bool = False
|
||||
|
||||
|
||||
class LayerNormPerHead(nn.Module):
|
||||
|
||||
def __init__(self, head_dim, num_heads, eps):
|
||||
super().__init__()
|
||||
self.norms = [
|
||||
nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
|
||||
]
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
w = mx.stack([n.weight for n in self.norms])
|
||||
return w * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
bias=config.use_qkv_bias,
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
bias=config.use_qkv_bias,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
self.qk_layernorm = config.qk_layernorm
|
||||
if self.qk_layernorm:
|
||||
self.q_layernorm = LayerNormPerHead(
|
||||
self.head_dim, self.num_heads, eps=config.layer_norm_eps
|
||||
)
|
||||
self.k_layernorm = LayerNormPerHead(
|
||||
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
B, L, D = queries.shape
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1)
|
||||
if self.qk_layernorm:
|
||||
queries = self.q_layernorm(queries)
|
||||
keys = self.k_layernorm(keys)
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=scale, mask=mask
|
||||
).astype(values.dtype)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config=config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
if not self.use_parallel_residual:
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
r = self.self_attn(h, mask, cache)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
out = x + r + self.mlp(h)
|
||||
else:
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class StableLM(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, cache=c)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = StableLM(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,169 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
norm_epsilon: float = 1e-5
|
||||
vocab_size: int = 49152
|
||||
rope_theta: float = 100000
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
|
||||
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.norm_epsilon
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Starcoder2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Starcoder2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user