mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00

Compare commits: 4b2a0df237 ... openlm (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 97939cc86e | |
| | 7c6ced183d | |
@@ -17,6 +17,30 @@ jobs:
             pre-commit run --all
             if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi

+  mlx_lm_build_and_test:
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.large.gen1
+    steps:
+      - checkout
+      - run:
+          name: Install dependencies
+          command: |
+            brew install python@3.8
+            python3.8 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install unittest-xml-reporting
+            cd llms/
+            pip install -e .
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            python -m xmlrunner discover -v llms/tests -o test-results/
+      - store_test_results:
+          path: test-results
+
 workflows:
   build_and_test:
     when:
@@ -24,6 +48,7 @@ workflows:
       pattern: "^(?!pull/)[-\\w]+$"
       value: << pipeline.git.branch >>
     jobs:
+      - mlx_lm_build_and_test
       - linux_build_and_test

   prb:
@@ -36,5 +61,7 @@ workflows:
           type: approval
       - apple/authenticate:
           context: pr-approval
+      - mlx_lm_build_and_test:
+          requires: [ hold ]
       - linux_build_and_test:
           requires: [ hold ]
.gitignore (vendored): 3 changes
@@ -6,9 +6,6 @@ __pycache__/
 # C extensions
 *.so

-# Vim
-*.swp
-
 # Distribution / packaging
 .Python
 build/
@@ -1,10 +1,10 @@
 repos:
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 25.1.0
+    rev: 24.3.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 6.0.0
+    rev: 5.13.2
     hooks:
       - id: isort
         args:
@@ -14,4 +14,3 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`.
README.md: 15 changes
@@ -4,12 +4,12 @@ This repo contains a variety of standalone examples using the [MLX
 framework](https://github.com/ml-explore/mlx).

 The [MNIST](mnist) example is a good starting point to learn how to use MLX.
-Some more useful examples are listed below. Check-out [MLX
-LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python
-package for LLMs with MLX.
+Some more useful examples are listed below.

 ### Text Models

+- [MLX LM](llms/README.md) a package for LLM text generation, fine-tuning, and more.
 - [Transformer language model](transformer_lm) training.
 - Minimal examples of large scale text generation with [LLaMA](llms/llama),
   [Mistral](llms/mistral), and more in the [LLMs](llms) directory.
@@ -20,23 +20,18 @@ package for LLMs with MLX.

 ### Image Models

-- Generating images
-  - [FLUX](flux)
-  - [Stable Diffusion or SDXL](stable_diffusion)
 - Image classification using [ResNets on CIFAR-10](cifar).
+- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
 - Convolutional variational autoencoder [(CVAE) on MNIST](cvae).

 ### Audio Models

 - Speech recognition with [OpenAI's Whisper](whisper).
-- Audio compression and generation with [Meta's EnCodec](encodec).
-- Music generation with [Meta's MusicGen](musicgen).

 ### Multimodal models

 - Joint text and image embeddings with [CLIP](clip).
 - Text generation from image and text inputs with [LLaVA](llava).
-- Image segmentation with [Segment Anything (SAM)](segment_anything).

 ### Other Models

@@ -46,7 +41,7 @@ package for LLMs with MLX.

 ### Hugging Face

-You can directly use or download converted checkpoints from the [MLX
+Note: You can now directly download a few converted checkpoints from the [MLX
 Community](https://huggingface.co/mlx-community) organization on Hugging Face.
 We encourage you to join the community and [contribute new
 models](https://github.com/ml-explore/mlx-examples/issues/155).
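The README hunk above points readers to the standalone MLX LM package for LLM text generation and fine-tuning. As a hedged sketch of the workflow that package provides (the `load`/`generate` names follow mlx-lm's documented API; the checkpoint name below is only an illustrative choice), generation looks roughly like this:

```python
# Minimal sketch, assuming `pip install mlx-lm` and that the example
# mlx-community checkpoint below is available on the Hugging Face Hub.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
text = generate(model, tokenizer, prompt="Write a haiku about arrays.", verbose=True)
print(text)
```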
@@ -48,17 +48,3 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.

 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
-
-## Distributed training
-
-The example also supports distributed data parallel training. You can launch a
-distributed training as follows:
-
-```shell
-$ cat >hostfile.json
-[
-    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
-    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
-]
-$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
-```
@@ -1,4 +1,3 @@
-import mlx.core as mx
 import numpy as np
 from mlx.data.datasets import load_cifar10

@@ -13,11 +12,8 @@ def get_cifar10(batch_size, root=None):
         x = x.astype("float32") / 255.0
         return (x - mean) / std

-    group = mx.distributed.init()
-
     tr_iter = (
         tr.shuffle()
-        .partition_if(group.size() > 1, group.size(), group.rank())
         .to_stream()
         .image_random_h_flip("image", prob=0.5)
         .pad("image", 0, 4, 4, 0.0)
@@ -29,11 +25,6 @@ def get_cifar10(batch_size, root=None):
     )

     test = load_cifar10(root=root, train=False)
-    test_iter = (
-        test.to_stream()
-        .partition_if(group.size() > 1, group.size(), group.rank())
-        .key_transform("image", normalize)
-        .batch(batch_size)
-    )
+    test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)

     return tr_iter, test_iter
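The removed `partition_if` calls above are what shard the CIFAR-10 buffers across ranks for distributed data-parallel training. A hedged, self-contained sketch of that pattern (it only restates the calls visible in the hunk; `mlx-data` and a distributed launch via `mlx.launch` are assumed):

```python
import mlx.core as mx
from mlx.data.datasets import load_cifar10

# With a single process, group.size() == 1 and partition_if is a no-op;
# under mlx.launch each rank keeps every size()-th sample starting at rank().
group = mx.distributed.init()
tr_iter = (
    load_cifar10()
    .shuffle()
    .partition_if(group.size() > 1, group.size(), group.rank())
    .to_stream()
    .batch(32)
)
```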
@@ -23,13 +23,6 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
 parser.add_argument("--cpu", action="store_true", help="use cpu only")


-def print_zero(group, *args, **kwargs):
-    if group.rank() != 0:
-        return
-    flush = kwargs.pop("flush", True)
-    print(*args, **kwargs, flush=flush)
-
-
 def eval_fn(model, inp, tgt):
     return mx.mean(mx.argmax(model(inp), axis=1) == tgt)

@@ -41,20 +34,9 @@ def train_epoch(model, train_iter, optimizer, epoch):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc

-    world = mx.distributed.init()
-    losses = 0
-    accuracies = 0
-    samples_per_sec = 0
-    count = 0
-
-    def average_stats(stats, count):
-        if world.size() == 1:
-            return [s / count for s in stats]
-
-        with mx.stream(mx.cpu):
-            stats = mx.distributed.all_sum(mx.array(stats))
-            count = mx.distributed.all_sum(count)
-        return (stats / count).tolist()
+    losses = []
+    accs = []
+    samples_per_sec = []

     state = [model.state, optimizer.state]

@@ -62,7 +44,6 @@ def train_epoch(model, train_iter, optimizer, epoch):
     def step(inp, tgt):
         train_step_fn = nn.value_and_grad(model, train_step)
         (loss, acc), grads = train_step_fn(model, inp, tgt)
-        grads = nn.utils.average_gradients(grads)
         optimizer.update(model, grads)
         return loss, acc

@@ -71,79 +52,69 @@ def train_epoch(model, train_iter, optimizer, epoch):
         y = mx.array(batch["label"])
         tic = time.perf_counter()
         loss, acc = step(x, y)
-        mx.eval(loss, acc, state)
+        mx.eval(state)
         toc = time.perf_counter()
-        losses += loss.item()
-        accuracies += acc.item()
-        samples_per_sec += x.shape[0] / (toc - tic)
-        count += 1
+        loss = loss.item()
+        acc = acc.item()
+        losses.append(loss)
+        accs.append(acc)
+        throughput = x.shape[0] / (toc - tic)
+        samples_per_sec.append(throughput)
         if batch_counter % 10 == 0:
-            l, a, s = average_stats(
-                [losses, accuracies, world.size() * samples_per_sec],
-                count,
-            )
-            print_zero(
-                world,
+            print(
                 " | ".join(
                     (
                         f"Epoch {epoch:02d} [{batch_counter:03d}]",
-                        f"Train loss {l:.3f}",
-                        f"Train acc {a:.3f}",
-                        f"Throughput: {s:.2f} images/second",
+                        f"Train loss {loss:.3f}",
+                        f"Train acc {acc:.3f}",
+                        f"Throughput: {throughput:.2f} images/second",
                     )
-                ),
+                )
             )

-    return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
+    mean_tr_loss = mx.mean(mx.array(losses))
+    mean_tr_acc = mx.mean(mx.array(accs))
+    samples_per_sec = mx.mean(mx.array(samples_per_sec))
+    return mean_tr_loss, mean_tr_acc, samples_per_sec


 def test_epoch(model, test_iter, epoch):
-    accuracies = 0
-    count = 0
+    accs = []
     for batch_counter, batch in enumerate(test_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         acc = eval_fn(model, x, y)
-        accuracies += acc.item()
-        count += 1
-
-    with mx.stream(mx.cpu):
-        accuracies = mx.distributed.all_sum(accuracies)
-        count = mx.distributed.all_sum(count)
-    return (accuracies / count).item()
+        acc_value = acc.item()
+        accs.append(acc_value)
+    mean_acc = mx.mean(mx.array(accs))
+    return mean_acc


 def main(args):
     mx.random.seed(args.seed)

-    # Initialize the distributed group and report the nodes that showed up
-    world = mx.distributed.init()
-    if world.size() > 1:
-        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
-
     model = getattr(resnet, args.arch)()

-    print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
+    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))

     optimizer = optim.Adam(learning_rate=args.lr)

     train_data, test_data = get_cifar10(args.batch_size)
     for epoch in range(args.epochs):
         tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
-        print_zero(
-            world,
+        print(
             " | ".join(
                 (
                     f"Epoch: {epoch}",
-                    f"avg. Train loss {tr_loss:.3f}",
-                    f"avg. Train acc {tr_acc:.3f}",
-                    f"Throughput: {throughput:.2f} images/sec",
+                    f"avg. Train loss {tr_loss.item():.3f}",
+                    f"avg. Train acc {tr_acc.item():.3f}",
+                    f"Throughput: {throughput.item():.2f} images/sec",
                 )
-            ),
+            )
         )

         test_acc = test_epoch(model, test_data, epoch)
-        print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
+        print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")

         train_data.reset()
         test_data.reset()
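The key distributed idiom removed in the hunk above is the `average_stats` helper: per-rank running totals are summed with `mx.distributed.all_sum` on the CPU stream and divided by the global count. Isolated here as a small sketch (the same code that appears in the diff, shown on its own; it assumes the script is started with `mlx.launch` when more than one rank is wanted):

```python
import mlx.core as mx

world = mx.distributed.init()


def average_stats(stats, count):
    # Single-process fallback: plain local averages.
    if world.size() == 1:
        return [s / count for s in stats]
    # Sum totals and counts across ranks, then divide once.
    with mx.stream(mx.cpu):
        stats = mx.distributed.all_sum(mx.array(stats))
        count = mx.distributed.all_sum(count)
    return (stats / count).tolist()


# e.g. average_stats([total_loss, total_acc], num_batches)
```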
@@ -63,7 +63,7 @@ def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
     )


-def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
+def get_model_path(path_or_hf_repo: str) -> Path:
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
         model_path = Path(
@@ -74,7 +74,6 @@ def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
                     "*.json",
                     "*.txt",
                 ],
-                force_download=force_download,
             )
         )
     return model_path
@@ -108,20 +107,14 @@ if __name__ == "__main__":
         type=str,
         default="float32",
     )
-    parser.add_argument(
-        "-f",
-        "--force-download",
-        help="Force download the model from Hugging Face.",
-        action="store_true",
-    )
     args = parser.parse_args()

-    torch_path = get_model_path(args.hf_repo, args.force_download)
+    torch_path = get_model_path(args.hf_repo)
     mlx_path = Path(args.mlx_path)
     mlx_path.mkdir(parents=True, exist_ok=True)

     print("[INFO] Loading")
-    torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
+    torch_weights = torch.load(torch_path / "pytorch_model.bin")
     print("[INFO] Converting")
     mlx_weights = {
         k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
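The dict comprehension at the end of this hunk calls a `torch_to_mx` helper whose body is not part of the diff. Purely as an illustration of what such a helper typically does (a hypothetical sketch, not the code from `convert.py`), converting a PyTorch tensor to an `mx.array` with a dtype cast could look like:

```python
import mlx.core as mx
import torch


def torch_to_mx(tensor: torch.Tensor, *, dtype: str) -> mx.array:
    # Hypothetical helper: numpy cannot represent bfloat16, so upcast first,
    # then cast to the requested MLX dtype after the copy.
    if dtype == "bfloat16":
        tensor = tensor.to(torch.float32)
    return mx.array(tensor.detach().cpu().numpy()).astype(getattr(mx, dtype))
```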
@@ -1,56 +0,0 @@
# Mirror of the Linear Probe Evaluation Script
# from the official CLIP Repository.

import mlx.core as mx
import numpy as np
from image_processor import CLIPImageProcessor
from mlx.data.datasets import load_cifar10
from model import CLIPModel
from PIL import Image
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm


def get_cifar10(batch_size, root=None):
    tr = load_cifar10(root=root).batch(batch_size)
    test = load_cifar10(root=root, train=False).batch(batch_size)

    return tr, test


def get_features(model, image_proc, iter):
    all_features = []
    all_labels = []

    for batch in tqdm(iter):
        image, label = batch["image"], batch["label"]
        x = image_proc([Image.fromarray(im) for im in image])
        y = mx.array(label)

        image_embeds = model.get_image_features(x)
        mx.eval(image_embeds)

        all_features.append(image_embeds)
        all_labels.append(y)

    return mx.concatenate(all_features), mx.concatenate(all_labels)


if __name__ == "__main__":
    model = CLIPModel.from_pretrained("mlx_model")
    image_proc = CLIPImageProcessor.from_pretrained("mlx_model")

    train_iter, test_iter = get_cifar10(batch_size=256)
    train_features, train_labels = get_features(model, image_proc, train_iter)
    test_features, test_labels = get_features(model, image_proc, test_iter)

    # Perform logistic regression
    # NOTE: The value of C should be determined via a hyperparameter sweep
    # using a validation split
    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
    classifier.fit(train_features, train_labels)

    # Evaluate using the logistic regression classifier
    predictions = classifier.predict(test_features)
    accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
    print(f"Accuracy = {accuracy:.3f}")
@@ -1,5 +1,4 @@
 mlx
-mlx-data
 numpy
 transformers
 torch
@@ -1,84 +0,0 @@
# EnCodec

An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.

### Setup

Install the requirements:

```
pip install -r requirements.txt
```

Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.

Install [FFmpeg](https://ffmpeg.org/):

```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```

Install SciPy:

```
pip install scipy
```

### Example

An example using the model:

```python
import mlx.core as mx
from encodec import EncodecModel
from utils import load_audio, save_audio

# Load the 48 KHz model and preprocessor.
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)

# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)

# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```

The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.

### Optional

To convert models, use the `convert.py` script. To see the options, run:

```bash
python convert.py -h
```

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
[code](https://github.com/facebookresearch/encodec) for more details.
@@ -1,31 +0,0 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx

from encodec import EncodecModel

model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
mx.eval(model, feats, mask)


@mx.compile
def fun():
    codes, scales = model.encode(feats, mask, bandwidth=3)
    reconstructed = model.decode(codes, scales, mask)
    return reconstructed


for _ in range(5):
    mx.eval(fun())

tic = time.time()
for _ in range(10):
    mx.eval(fun())
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f}")
@@ -1,34 +0,0 @@
# Copyright © 2024 Apple Inc.

import time

import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel

processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
    pt_audio = pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
    )
    torch.mps.synchronize()


for _ in range(5):
    fun()

tic = time.time()
for _ in range(10):
    fun()
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f}")
@@ -1,212 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from textwrap import dedent
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any, Dict, Union
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
import encodec
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_from_hub(hf_repo: str) -> Path:
|
|
||||||
model_path = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=hf_repo,
|
|
||||||
allow_patterns=["*.json", "*.safetensors"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
|
|
||||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
|
||||||
"""
|
|
||||||
Uploads the model to Hugging Face hub.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): Local path to the model.
|
|
||||||
upload_repo (str): Name of the HF repo to upload to.
|
|
||||||
hf_path (str): Path to the original Hugging Face model.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
|
|
||||||
from huggingface_hub import HfApi, ModelCard, logging
|
|
||||||
|
|
||||||
content = dedent(
|
|
||||||
f"""
|
|
||||||
---
|
|
||||||
language: en
|
|
||||||
license: other
|
|
||||||
library: mlx
|
|
||||||
tags:
|
|
||||||
- mlx
|
|
||||||
---
|
|
||||||
|
|
||||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
|
|
||||||
converted to MLX format from
|
|
||||||
[{hf_path}](https://huggingface.co/{hf_path}).
|
|
||||||
|
|
||||||
This model is intended to be used with the [EnCodec MLX
|
|
||||||
example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
card = ModelCard(content)
|
|
||||||
card.save(os.path.join(path, "README.md"))
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
|
||||||
|
|
||||||
api = HfApi()
|
|
||||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
|
||||||
api.upload_folder(
|
|
||||||
folder_path=path,
|
|
||||||
repo_id=upload_repo,
|
|
||||||
repo_type="model",
|
|
||||||
multi_commits=True,
|
|
||||||
multi_commits_verbose=True,
|
|
||||||
)
|
|
||||||
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
|
|
||||||
|
|
||||||
|
|
||||||
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
|
||||||
if isinstance(save_path, str):
|
|
||||||
save_path = Path(save_path)
|
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
total_size = sum(v.nbytes for v in weights.values())
|
|
||||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
|
||||||
mx.save_safetensors(
|
|
||||||
str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
|
|
||||||
)
|
|
||||||
|
|
||||||
for weight_name in weights.keys():
|
|
||||||
index_data["weight_map"][weight_name] = "model.safetensors"
|
|
||||||
|
|
||||||
index_data["weight_map"] = {
|
|
||||||
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(save_path / "model.safetensors.index.json", "w") as f:
|
|
||||||
json.dump(index_data, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
|
||||||
config: dict,
|
|
||||||
config_path: Union[str, Path],
|
|
||||||
) -> None:
|
|
||||||
"""Save the model configuration to the ``config_path``.
|
|
||||||
|
|
||||||
The final configuration will be sorted before saving for better readability.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): The model configuration.
|
|
||||||
config_path (Union[str, Path]): Model configuration file path.
|
|
||||||
"""
|
|
||||||
# Clean unused keys
|
|
||||||
config.pop("_name_or_path", None)
|
|
||||||
|
|
||||||
# sort the config for better readability
|
|
||||||
config = dict(sorted(config.items()))
|
|
||||||
|
|
||||||
# write the updated config to the config_path (if provided)
|
|
||||||
with open(config_path, "w") as fid:
|
|
||||||
json.dump(config, fid, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
def convert(
|
|
||||||
upload: bool,
|
|
||||||
model: str,
|
|
||||||
dtype: str = None,
|
|
||||||
):
|
|
||||||
hf_repo = f"facebook/encodec_{model}"
|
|
||||||
mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
|
|
||||||
path = fetch_from_hub(hf_repo)
|
|
||||||
save_path = Path("mlx_models")
|
|
||||||
|
|
||||||
weights = mx.load(str(Path(path) / "model.safetensors"))
|
|
||||||
|
|
||||||
with open(path / "config.json", "r") as fid:
|
|
||||||
config = SimpleNamespace(**json.load(fid))
|
|
||||||
|
|
||||||
model = encodec.EncodecModel(config)
|
|
||||||
|
|
||||||
new_weights = {}
|
|
||||||
for k, v in weights.items():
|
|
||||||
basename, pname = k.rsplit(".", 1)
|
|
||||||
if pname == "weight_v":
|
|
||||||
g = weights[basename + ".weight_g"]
|
|
||||||
v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
|
|
||||||
k = basename + ".weight"
|
|
||||||
elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
|
|
||||||
continue
|
|
||||||
elif "lstm" in basename:
|
|
||||||
w_or_b, ih_or_hh, ln = pname.split("_")
|
|
||||||
if w_or_b == "weight":
|
|
||||||
new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
|
|
||||||
elif w_or_b == "bias" and ih_or_hh == "ih":
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
v = v + weights[k.replace("_hh_", "_ih_")]
|
|
||||||
new_pname = "bias"
|
|
||||||
k = basename + "." + ln[1:] + "." + new_pname
|
|
||||||
if "conv.weight" in k:
|
|
||||||
# Possibly a transposed conv which has a different order
|
|
||||||
if "decoder" in k:
|
|
||||||
ln = int(k.split(".")[2])
|
|
||||||
if "conv" in model.decoder.layers[ln] and isinstance(
|
|
||||||
model.decoder.layers[ln].conv, nn.ConvTranspose1d
|
|
||||||
):
|
|
||||||
v = mx.moveaxis(v, 0, 2)
|
|
||||||
else:
|
|
||||||
v = mx.moveaxis(v, 1, 2)
|
|
||||||
else:
|
|
||||||
v = mx.moveaxis(v, 1, 2)
|
|
||||||
|
|
||||||
new_weights[k] = v
|
|
||||||
weights = new_weights
|
|
||||||
|
|
||||||
model.load_weights(list(weights.items()))
|
|
||||||
|
|
||||||
if dtype is not None:
|
|
||||||
t = getattr(mx, dtype)
|
|
||||||
weights = {k: v.astype(t) for k, v in weights.items()}
|
|
||||||
|
|
||||||
if isinstance(save_path, str):
|
|
||||||
save_path = Path(save_path)
|
|
||||||
|
|
||||||
save_weights(save_path, weights)
|
|
||||||
|
|
||||||
save_config(vars(config), config_path=save_path / "config.json")
|
|
||||||
|
|
||||||
if upload:
|
|
||||||
upload_to_hub(save_path, mlx_repo, hf_repo)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
type=str,
|
|
||||||
default="48khz",
|
|
||||||
help="",
|
|
||||||
choices=["24khz", "32khz", "48khz"],
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--upload",
|
|
||||||
action="store_true",
|
|
||||||
help="Upload the weights to Hugging Face.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dtype",
|
|
||||||
type=str,
|
|
||||||
help="Data type to convert the model to.",
|
|
||||||
default="float32",
|
|
||||||
choices=["float32", "bfloat16", "float16"],
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
convert(upload=args.upload, model=args.model, dtype=args.dtype)
|
|
||||||
@@ -1,741 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
_lstm_kernel = mx.fast.metal_kernel(
|
|
||||||
name="lstm",
|
|
||||||
input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
|
|
||||||
output_names=["hidden_state", "cell_state"],
|
|
||||||
header="""
|
|
||||||
template <typename T>
|
|
||||||
T sigmoid(T x) {
|
|
||||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
|
||||||
return (x < 0) ? 1 - y : y;
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
source="""
|
|
||||||
uint b = thread_position_in_grid.x;
|
|
||||||
uint d = hidden_size * 4;
|
|
||||||
|
|
||||||
uint elem = b * d + thread_position_in_grid.y;
|
|
||||||
uint index = elem;
|
|
||||||
uint x_index = b * num_time_steps * d + time_step * d + index;
|
|
||||||
|
|
||||||
auto i = sigmoid(h_in[index] + x[x_index]);
|
|
||||||
index += hidden_size;
|
|
||||||
x_index += hidden_size;
|
|
||||||
auto f = sigmoid(h_in[index] + x[x_index]);
|
|
||||||
index += hidden_size;
|
|
||||||
x_index += hidden_size;
|
|
||||||
auto g = metal::precise::tanh(h_in[index] + x[x_index]);
|
|
||||||
index += hidden_size;
|
|
||||||
x_index += hidden_size;
|
|
||||||
auto o = sigmoid(h_in[index] + x[x_index]);
|
|
||||||
|
|
||||||
cell_state[elem] = f * cell[elem] + i * g;
|
|
||||||
hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def lstm_custom(x, h_in, cell, time_step):
|
|
||||||
assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
|
|
||||||
out_shape = cell.shape
|
|
||||||
return _lstm_kernel(
|
|
||||||
inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
|
|
||||||
output_shapes=[out_shape, out_shape],
|
|
||||||
output_dtypes=[h_in.dtype, h_in.dtype],
|
|
||||||
grid=(x.shape[0], h_in.size // 4, 1),
|
|
||||||
threadgroup=(256, 1, 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LSTM(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
hidden_size: int,
|
|
||||||
bias: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.Wx = mx.zeros((4 * hidden_size, input_size))
|
|
||||||
self.Wh = mx.zeros((4 * hidden_size, hidden_size))
|
|
||||||
self.bias = mx.zeros((4 * hidden_size,)) if bias else None
|
|
||||||
|
|
||||||
def __call__(self, x, hidden=None, cell=None):
|
|
||||||
if self.bias is not None:
|
|
||||||
x = mx.addmm(self.bias, x, self.Wx.T)
|
|
||||||
else:
|
|
||||||
x = x @ self.Wx.T
|
|
||||||
|
|
||||||
all_hidden = []
|
|
||||||
|
|
||||||
B = x.shape[0]
|
|
||||||
cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
|
|
||||||
for t in range(x.shape[-2]):
|
|
||||||
if hidden is None:
|
|
||||||
hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
|
|
||||||
else:
|
|
||||||
hidden = hidden @ self.Wh.T
|
|
||||||
hidden, cell = lstm_custom(x, hidden, cell, t)
|
|
||||||
all_hidden.append(hidden)
|
|
||||||
|
|
||||||
return mx.stack(all_hidden, axis=-2)
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecConv1d(nn.Module):
|
|
||||||
"""Conv1d with asymmetric or causal padding and normalization."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: int,
|
|
||||||
stride: int = 1,
|
|
||||||
dilation: int = 1,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.causal = config.use_causal_conv
|
|
||||||
self.pad_mode = config.pad_mode
|
|
||||||
self.norm_type = config.norm_type
|
|
||||||
|
|
||||||
self.conv = nn.Conv1d(
|
|
||||||
in_channels, out_channels, kernel_size, stride, dilation=dilation
|
|
||||||
)
|
|
||||||
if self.norm_type == "time_group_norm":
|
|
||||||
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
|
|
||||||
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
# Effective kernel size with dilations.
|
|
||||||
self.kernel_size = (kernel_size - 1) * dilation + 1
|
|
||||||
|
|
||||||
self.padding_total = kernel_size - stride
|
|
||||||
|
|
||||||
def _get_extra_padding_for_conv1d(
|
|
||||||
self,
|
|
||||||
hidden_states: mx.array,
|
|
||||||
) -> mx.array:
|
|
||||||
length = hidden_states.shape[1]
|
|
||||||
n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
|
|
||||||
n_frames = int(math.ceil(n_frames)) - 1
|
|
||||||
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
|
|
||||||
return ideal_length - length
|
|
||||||
|
|
||||||
def _pad1d(
|
|
||||||
self,
|
|
||||||
hidden_states: mx.array,
|
|
||||||
paddings: Tuple[int, int],
|
|
||||||
mode: str = "zero",
|
|
||||||
value: float = 0.0,
|
|
||||||
):
|
|
||||||
if mode != "reflect":
|
|
||||||
return mx.pad(
|
|
||||||
hidden_states, paddings, mode="constant", constant_values=value
|
|
||||||
)
|
|
||||||
|
|
||||||
length = hidden_states.shape[1]
|
|
||||||
prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
|
|
||||||
suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
|
|
||||||
return mx.concatenate([prefix, hidden_states, suffix], axis=1)
|
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
|
||||||
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
|
|
||||||
|
|
||||||
if self.causal:
|
|
||||||
# Left padding for causal
|
|
||||||
hidden_states = self._pad1d(
|
|
||||||
hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Asymmetric padding required for odd strides
|
|
||||||
padding_right = self.padding_total // 2
|
|
||||||
padding_left = self.padding_total - padding_right
|
|
||||||
hidden_states = self._pad1d(
|
|
||||||
hidden_states,
|
|
||||||
(padding_left, padding_right + extra_padding),
|
|
||||||
mode=self.pad_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.conv(hidden_states)
|
|
||||||
|
|
||||||
if self.norm_type == "time_group_norm":
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecConvTranspose1d(nn.Module):
|
|
||||||
"""ConvTranspose1d with asymmetric or causal padding and normalization."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
kernel_size: int,
|
|
||||||
stride: int = 1,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.causal = config.use_causal_conv
|
|
||||||
self.trim_right_ratio = config.trim_right_ratio
|
|
||||||
self.norm_type = config.norm_type
|
|
||||||
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
|
||||||
if config.norm_type == "time_group_norm":
|
|
||||||
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
|
|
||||||
self.padding_total = kernel_size - stride
|
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
|
||||||
hidden_states = self.conv(hidden_states)
|
|
||||||
|
|
||||||
if self.norm_type == "time_group_norm":
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
if self.causal:
|
|
||||||
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
|
|
||||||
else:
|
|
||||||
padding_right = self.padding_total // 2
|
|
||||||
|
|
||||||
padding_left = self.padding_total - padding_right
|
|
||||||
|
|
||||||
end = hidden_states.shape[1] - padding_right
|
|
||||||
hidden_states = hidden_states[:, padding_left:end, :]
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecLSTM(nn.Module):
|
|
||||||
def __init__(self, config, dimension):
|
|
||||||
super().__init__()
|
|
||||||
self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]
|
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
|
||||||
h = hidden_states
|
|
||||||
for lstm in self.lstm:
|
|
||||||
h = lstm(h)
|
|
||||||
return h + hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecResnetBlock(nn.Module):
|
|
||||||
"""
|
|
||||||
Residual block from SEANet model as used by EnCodec.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, dim: int, dilations: List[int]):
|
|
||||||
super().__init__()
|
|
||||||
kernel_sizes = (config.residual_kernel_size, 1)
|
|
||||||
if len(kernel_sizes) != len(dilations):
|
|
||||||
raise ValueError("Number of kernel sizes should match number of dilations")
|
|
||||||
|
|
||||||
hidden = dim // config.compress
|
|
||||||
block = []
|
|
||||||
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
|
||||||
in_chs = dim if i == 0 else hidden
|
|
||||||
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
|
||||||
block += [nn.ELU()]
|
|
||||||
block += [
|
|
||||||
EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
|
|
||||||
]
|
|
||||||
self.block = block
|
|
||||||
|
|
||||||
if getattr(config, "use_conv_shortcut", True):
|
|
||||||
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
|
|
||||||
else:
|
|
||||||
self.shortcut = nn.Identity()
|
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
|
||||||
residual = hidden_states
|
|
||||||
for layer in self.block:
|
|
||||||
hidden_states = layer(hidden_states)
|
|
||||||
|
|
||||||
return self.shortcut(residual) + hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecEncoder(nn.Module):
|
|
||||||
"""SEANet encoder as used by EnCodec."""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
model = [
|
|
||||||
EncodecConv1d(
|
|
||||||
config, config.audio_channels, config.num_filters, config.kernel_size
|
|
||||||
)
|
|
||||||
]
|
|
||||||
scaling = 1
|
|
||||||
|
|
||||||
for ratio in reversed(config.upsampling_ratios):
|
|
||||||
current_scale = scaling * config.num_filters
|
|
||||||
for j in range(config.num_residual_layers):
|
|
||||||
model += [
|
|
||||||
EncodecResnetBlock(
|
|
||||||
config, current_scale, [config.dilation_growth_rate**j, 1]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
model += [nn.ELU()]
|
|
||||||
model += [
|
|
||||||
EncodecConv1d(
|
|
||||||
config,
|
|
||||||
current_scale,
|
|
||||||
current_scale * 2,
|
|
||||||
kernel_size=ratio * 2,
|
|
||||||
stride=ratio,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
scaling *= 2
|
|
||||||
|
|
||||||
model += [EncodecLSTM(config, scaling * config.num_filters)]
|
|
||||||
model += [nn.ELU()]
|
|
||||||
model += [
|
|
||||||
EncodecConv1d(
|
|
||||||
config,
|
|
||||||
scaling * config.num_filters,
|
|
||||||
config.hidden_size,
|
|
||||||
config.last_kernel_size,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.layers = model
|
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
|
||||||
for layer in self.layers:
|
|
||||||
hidden_states = layer(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecDecoder(nn.Module):
|
|
||||||
"""SEANet decoder as used by EnCodec."""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
scaling = int(2 ** len(config.upsampling_ratios))
|
|
||||||
model = [
|
|
||||||
EncodecConv1d(
|
|
||||||
config,
|
|
||||||
config.hidden_size,
|
|
||||||
scaling * config.num_filters,
|
|
||||||
config.kernel_size,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
model += [EncodecLSTM(config, scaling * config.num_filters)]
|
|
||||||
|
|
||||||
for ratio in config.upsampling_ratios:
|
|
||||||
current_scale = scaling * config.num_filters
|
|
||||||
model += [nn.ELU()]
|
|
||||||
model += [
|
|
||||||
EncodecConvTranspose1d(
|
|
||||||
config,
|
|
||||||
current_scale,
|
|
||||||
current_scale // 2,
|
|
||||||
kernel_size=ratio * 2,
|
|
||||||
stride=ratio,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
for j in range(config.num_residual_layers):
|
|
||||||
model += [
|
|
||||||
EncodecResnetBlock(
|
|
||||||
config, current_scale // 2, (config.dilation_growth_rate**j, 1)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
scaling //= 2
|
|
||||||
|
|
||||||
model += [nn.ELU()]
|
|
||||||
model += [
|
|
||||||
EncodecConv1d(
|
|
||||||
config,
|
|
||||||
config.num_filters,
|
|
||||||
config.audio_channels,
|
|
||||||
config.last_kernel_size,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
self.layers = model
|
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
|
||||||
for layer in self.layers:
|
|
||||||
hidden_states = layer(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecEuclideanCodebook(nn.Module):
|
|
||||||
"""Codebook with Euclidean distance."""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
|
|
||||||
|
|
||||||
def quantize(self, hidden_states):
|
|
||||||
embed = self.embed.T
|
|
||||||
scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
|
|
||||||
dist = -(
|
|
||||||
scaled_states
|
|
||||||
- 2 * hidden_states @ embed
|
|
||||||
+ embed.square().sum(axis=0, keepdims=True)
|
|
||||||
)
|
|
||||||
embed_ind = dist.argmax(axis=-1)
|
|
||||||
return embed_ind
|
|
||||||
|
|
||||||
def encode(self, hidden_states):
|
|
||||||
shape = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.reshape((-1, shape[-1]))
|
|
||||||
embed_ind = self.quantize(hidden_states)
|
|
||||||
embed_ind = embed_ind.reshape(*shape[:-1])
|
|
||||||
return embed_ind
|
|
||||||
|
|
||||||
def decode(self, embed_ind):
|
|
||||||
return self.embed[embed_ind]
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecVectorQuantization(nn.Module):
|
|
||||||
"""
|
|
||||||
Vector quantization implementation. Currently supports only euclidean distance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook = EncodecEuclideanCodebook(config)
|
|
||||||
|
|
||||||
def encode(self, hidden_states):
|
|
||||||
return self.codebook.encode(hidden_states)
|
|
||||||
|
|
||||||
def decode(self, embed_ind):
|
|
||||||
return self.codebook.decode(embed_ind)
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecResidualVectorQuantizer(nn.Module):
|
|
||||||
"""Residual Vector Quantizer."""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook_size = config.codebook_size
|
|
||||||
|
|
||||||
hop_length = np.prod(config.upsampling_ratios)
|
|
||||||
self.frame_rate = math.ceil(config.sampling_rate / hop_length)
|
|
||||||
self.num_quantizers = int(
|
|
||||||
1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
|
|
||||||
)
|
|
||||||
self.layers = [
|
|
||||||
EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_num_quantizers_for_bandwidth(
|
|
||||||
self, bandwidth: Optional[float] = None
|
|
||||||
) -> int:
|
|
||||||
"""Return num_quantizers based on specified target bandwidth."""
|
|
||||||
bw_per_q = math.log2(self.codebook_size) * self.frame_rate
|
|
||||||
num_quantizers = self.num_quantizers
|
|
||||||
if bandwidth is not None and bandwidth > 0.0:
|
|
||||||
num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
|
|
||||||
return num_quantizers
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self, embeddings: mx.array, bandwidth: Optional[float] = None
|
|
||||||
) -> mx.array:
|
|
||||||
"""
|
|
||||||
Encode a given input array with the specified frame rate at the given
|
|
||||||
bandwidth. The RVQ encode method sets the appropriate number of
|
|
||||||
quantizers to use and returns indices for each quantizer.
|
|
||||||
"""
|
|
||||||
num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
|
|
||||||
residual = embeddings
|
|
||||||
all_indices = []
|
|
||||||
for layer in self.layers[:num_quantizers]:
|
|
||||||
indices = layer.encode(residual)
|
|
||||||
quantized = layer.decode(indices)
|
|
||||||
residual = residual - quantized
|
|
||||||
all_indices.append(indices)
|
|
||||||
out_indices = mx.stack(all_indices, axis=1)
|
|
||||||
return out_indices
|
|
||||||
|
|
||||||
def decode(self, codes: mx.array) -> mx.array:
|
|
||||||
"""Decode the given codes to the quantized representation."""
|
|
||||||
quantized_out = None
|
|
||||||
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
|
|
||||||
layer = self.layers[i]
|
|
||||||
quantized = layer.decode(indices.squeeze(1))
|
|
||||||
if quantized_out is None:
|
|
||||||
quantized_out = quantized
|
|
||||||
else:
|
|
||||||
quantized_out = quantized + quantized_out
|
|
||||||
return quantized_out
|
|
||||||
|
|
||||||
|
|
||||||
class EncodecModel(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.encoder = EncodecEncoder(config)
|
|
||||||
self.decoder = EncodecDecoder(config)
|
|
||||||
self.quantizer = EncodecResidualVectorQuantizer(config)
|
|
||||||
|
|
||||||
def _encode_frame(
|
|
||||||
self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
|
|
||||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
|
||||||
"""
|
|
||||||
Encodes the given input using the underlying VQVAE.
|
|
||||||
"""
|
|
||||||
length = input_values.shape[1]
|
|
||||||
duration = length / self.config.sampling_rate
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.config.chunk_length_s is not None
|
|
||||||
and duration > 1e-5 + self.config.chunk_length_s
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
|
|
||||||
)
|
|
||||||
|
|
||||||
scale = None
|
|
||||||
if self.config.normalize:
|
|
||||||
# if the padding is non zero
|
|
||||||
input_values = input_values * padding_mask[..., None]
|
|
||||||
mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
|
|
||||||
scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
|
|
||||||
input_values = input_values / scale
|
|
||||||
|
|
||||||
embeddings = self.encoder(input_values)
|
|
||||||
codes = self.quantizer.encode(embeddings, bandwidth)
|
|
||||||
return codes, scale
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
input_values: mx.array,
|
|
||||||
padding_mask: mx.array = None,
|
|
||||||
bandwidth: Optional[float] = None,
|
|
||||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
|
||||||
"""
|
|
||||||
Encodes the input audio waveform into discrete codes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_values (mx.array): The input audio waveform with shape
|
|
||||||
``(batch_size, channels, sequence_length)``.
|
|
||||||
padding_mask (mx.array): Padding mask used to pad the ``input_values``.
|
|
||||||
bandwidth (float, optional): The target bandwidth. Must be one of
|
|
||||||
``config.target_bandwidths``. If ``None``, uses the smallest
|
|
||||||
possible bandwidth. bandwidth is represented as a thousandth of
|
|
||||||
what it is, e.g. 6kbps bandwidth is represented as bandwidth == 6.0
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of frames containing the discrete encoded codes for the
|
|
||||||
input audio waveform, along with rescaling factors for each chunk
|
|
||||||
when ``config.normalize==True``. Each frame is a tuple ``(codebook,
|
|
||||||
scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
|
|
||||||
frames)``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if bandwidth is None:
|
|
||||||
bandwidth = self.config.target_bandwidths[0]
|
|
||||||
if bandwidth not in self.config.target_bandwidths:
|
|
||||||
raise ValueError(
|
|
||||||
f"This model doesn't support the bandwidth {bandwidth}. "
|
|
||||||
f"Select one of {self.config.target_bandwidths}."
|
|
||||||
)
|
|
||||||
|
|
||||||
_, input_length, channels = input_values.shape
|
|
||||||
|
|
||||||
if channels < 1 or channels > 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of audio channels must be 1 or 2, but got {channels}"
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk_length = self.chunk_length
|
|
||||||
if chunk_length is None:
|
|
||||||
chunk_length = input_length
|
|
||||||
stride = input_length
|
|
||||||
else:
|
|
||||||
stride = self.chunk_stride
|
|
||||||
|
|
||||||
if padding_mask is None:
|
|
||||||
padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
|
|
||||||
encoded_frames = []
|
|
||||||
scales = []
|
|
||||||
|
|
||||||
step = chunk_length - stride
|
|
||||||
if (input_length % stride) != step:
|
|
||||||
raise ValueError(
|
|
||||||
"The input length is not properly padded for batched chunked "
|
|
||||||
"encoding. Make sure to pad the input correctly."
|
|
||||||
)
|
|
||||||
|
|
||||||
for offset in range(0, input_length - step, stride):
|
|
||||||
mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
|
|
||||||
frame = input_values[:, offset : offset + chunk_length]
|
|
||||||
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
|
|
||||||
encoded_frames.append(encoded_frame)
|
|
||||||
scales.append(scale)
|
|
||||||
|
|
||||||
encoded_frames = mx.stack(encoded_frames)
|
|
||||||
|
|
||||||
return (encoded_frames, scales)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _linear_overlap_add(frames: List[mx.array], stride: int):
|
|
||||||
if len(frames) == 0:
|
|
||||||
raise ValueError("`frames` cannot be an empty list.")
|
|
||||||
|
|
||||||
dtype = frames[0].dtype
|
|
||||||
N, frame_length, C = frames[0].shape
|
|
||||||
total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
|
|
||||||
|
|
||||||
time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
|
|
||||||
weight = 0.5 - (time_vec - 0.5).abs()
|
|
||||||
|
|
||||||
weight = weight[:, None]
|
|
||||||
sum_weight = mx.zeros((total_size, 1), dtype=dtype)
|
|
||||||
out = mx.zeros((N, total_size, C), dtype=dtype)
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
for frame in frames:
|
|
||||||
frame_length = frame.shape[1]
|
|
||||||
out[:, offset : offset + frame_length] += weight[:frame_length] * frame
|
|
||||||
sum_weight[offset : offset + frame_length] += weight[:frame_length]
|
|
||||||
offset += stride
|
|
||||||
|
|
||||||
return out / sum_weight
|
|
||||||
|
|
||||||
def _decode_frame(
|
|
||||||
self, codes: mx.array, scale: Optional[mx.array] = None
|
|
||||||
) -> mx.array:
|
|
||||||
embeddings = self.quantizer.decode(codes)
|
|
||||||
outputs = self.decoder(embeddings)
|
|
||||||
if scale is not None:
|
|
||||||
outputs = outputs * scale
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def channels(self):
|
|
||||||
return self.config.audio_channels
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sampling_rate(self):
|
|
||||||
return self.config.sampling_rate
|
|
||||||
|
|
||||||
@property
|
|
||||||
def chunk_length(self):
|
|
||||||
if self.config.chunk_length_s is None:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return int(self.config.chunk_length_s * self.config.sampling_rate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def chunk_stride(self):
|
|
||||||
if self.config.chunk_length_s is None or self.config.overlap is None:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
|
|
||||||
|
|
||||||
    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: Union[mx.array, List[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end should be trimmed.

        Args:
            audio_codes (mx.array): Discrete code embeddings of shape
                ``(batch_size, nb_chunks, chunk_length)``.
            audio_scales (mx.array): Scaling factor for each input.
            padding_mask (mx.array): Padding mask.
        """
        chunk_length = self.chunk_length
        if chunk_length is None:
            if audio_codes.shape[1] != 1:
                raise ValueError(f"Expected one frame, got {audio_codes.shape[1]}")
            audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
        else:
            decoded_frames = []

            for frame, scale in zip(audio_codes, audio_scales):
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)

            audio_values = self._linear_overlap_add(
                decoded_frames, self.chunk_stride or 1
            )

        # Truncate based on the padding mask.
        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
            audio_values = audio_values[:, : padding_mask.shape[1]]
        return audio_values

    @classmethod
    def from_pretrained(cls, path_or_repo: str):
        from huggingface_hub import snapshot_download

        path = Path(path_or_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_repo,
                    allow_patterns=["*.json", "*.safetensors", "*.model"],
                )
            )

        with open(path / "config.json", "r") as f:
            config = SimpleNamespace(**json.load(f))

        model = EncodecModel(config)
        model.load_weights(str(path / "model.safetensors"))
        processor = functools.partial(
            preprocess_audio,
            sampling_rate=config.sampling_rate,
            chunk_length=model.chunk_length,
            chunk_stride=model.chunk_stride,
        )
        mx.eval(model)
        return model, processor


def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Prepare inputs for the EnCodec model.

    Args:
        raw_audio (mx.array or List[mx.array]): The sequence or batch of
            sequences to be processed.
        sampling_rate (int): The sampling rate at which the audio waveform
            should be digitized.
        chunk_length (int, optional): The model's chunk length.
        chunk_stride (int, optional): The model's chunk stride.
    """
    if not isinstance(raw_audio, list):
        raw_audio = [raw_audio]

    raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]

    max_length = max(array.shape[0] for array in raw_audio)
    if chunk_length is not None:
        max_length += chunk_length - (max_length % chunk_stride)

    inputs = []
    masks = []
    for x in raw_audio:
        length = x.shape[0]
        mask = mx.ones((length,), dtype=mx.bool_)
        difference = max_length - length
        if difference > 0:
            mask = mx.pad(mask, (0, difference))
            x = mx.pad(x, ((0, difference), (0, 0)))
        inputs.append(x)
        masks.append(mask)
    return mx.stack(inputs), mx.stack(masks)

@@ -1,39 +0,0 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
from utils import load_audio, save_audio

from encodec import EncodecModel

# Load the 48 kHz model and preprocessor.
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)


# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)


# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
@@ -1,3 +0,0 @@
mlx>=0.18
numpy
huggingface_hub
@@ -1,67 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from transformers import AutoProcessor
|
|
||||||
from transformers import EncodecModel as PTEncodecModel
|
|
||||||
|
|
||||||
from encodec import EncodecModel, preprocess_audio
|
|
||||||
|
|
||||||
|
|
||||||
def compare_processors():
|
|
||||||
np.random.seed(0)
|
|
||||||
audio_length = 95500
|
|
||||||
audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)
|
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
|
|
||||||
|
|
||||||
pt_inputs = processor(
|
|
||||||
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
|
|
||||||
)
|
|
||||||
mx_inputs = preprocess_audio(
|
|
||||||
mx.array(audio).T,
|
|
||||||
processor.sampling_rate,
|
|
||||||
processor.chunk_length,
|
|
||||||
processor.chunk_stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
|
|
||||||
assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])
|
|
||||||
|
|
||||||
|
|
||||||
def compare_models():
|
|
||||||
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
|
|
||||||
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
|
||||||
|
|
||||||
np.random.seed(0)
|
|
||||||
audio_length = 190560
|
|
||||||
audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
|
|
||||||
mask = np.ones((1, audio_length), dtype=np.int32)
|
|
||||||
pt_encoded = pt_model.encode(
|
|
||||||
torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
|
|
||||||
)
|
|
||||||
mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
|
|
||||||
pt_codes = pt_encoded.audio_codes.numpy()
|
|
||||||
mx_codes = mx_encoded[0]
|
|
||||||
assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
|
|
||||||
|
|
||||||
for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
|
|
||||||
if mx_scale is not None:
|
|
||||||
pt_scale = pt_scale.numpy()
|
|
||||||
assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)
|
|
||||||
|
|
||||||
pt_audio = pt_model.decode(
|
|
||||||
pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
|
|
||||||
)
|
|
||||||
pt_audio = pt_audio[0].squeeze().T.detach().numpy()
|
|
||||||
mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
|
|
||||||
mx_audio = mx_audio.squeeze()
|
|
||||||
assert np.allclose(
|
|
||||||
pt_audio, mx_audio, atol=1e-4, rtol=1e-4
|
|
||||||
), "Decoding audio mismatch"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
compare_processors()
|
|
||||||
compare_models()
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
|
||||||
"""
|
|
||||||
Save audio to a wave (.wav) file.
|
|
||||||
"""
|
|
||||||
from scipy.io.wavfile import write
|
|
||||||
|
|
||||||
audio = (audio * 32767).astype(mx.int16)
|
|
||||||
write(file, sampling_rate, np.array(audio))
|
|
||||||
|
|
||||||
|
|
||||||
def load_audio(file: str, sampling_rate: int, channels: int):
|
|
||||||
"""
|
|
||||||
Read audio into an mx.array, resampling if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (str): The audio file to open.
|
|
||||||
sampling_rate (int): The sample rate to resample the audio at if needed.
|
|
||||||
channels (int): The number of audio channels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An mx.array containing the audio waveform in float32.
|
|
||||||
"""
|
|
||||||
from subprocess import CalledProcessError, run
|
|
||||||
|
|
||||||
# This launches a subprocess to decode audio while down-mixing
|
|
||||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
|
||||||
# fmt: off
|
|
||||||
cmd = [
|
|
||||||
"ffmpeg",
|
|
||||||
"-nostdin",
|
|
||||||
"-threads", "0",
|
|
||||||
"-i", file,
|
|
||||||
"-f", "s16le",
|
|
||||||
"-ac", str(channels),
|
|
||||||
"-acodec", "pcm_s16le",
|
|
||||||
"-ar", str(sampling_rate),
|
|
||||||
"-"
|
|
||||||
]
|
|
||||||
# fmt: on
|
|
||||||
try:
|
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
|
||||||
except CalledProcessError as e:
|
|
||||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
|
||||||
|
|
||||||
out = mx.array(np.frombuffer(out, np.int16))
|
|
||||||
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
|
|
||||||
flux/README.md
@@ -1,281 +0,0 @@
|
|||||||
FLUX
|
|
||||||
====
|
|
||||||
|
|
||||||
FLUX implementation in MLX. The implementation is ported directly from
|
|
||||||
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
|
|
||||||
and the model weights are downloaded directly from the Hugging Face Hub.
|
|
||||||
|
|
||||||
The goal of this example is to be clean, educational and to allow for
|
|
||||||
experimentation with finetuning FLUX models as well as adding extra
|
|
||||||
functionality such as in-/outpainting, guidance with custom losses etc.
|
|
||||||
|
|
||||||

|
|
||||||
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
|
|
||||||
tron emanating futuristic technology with the word "MLX" in the center with
|
|
||||||
capital red letters.'*
|
|
||||||
|
|
||||||
Installation
|
|
||||||
------------
|
|
||||||
|
|
||||||
The dependencies are minimal, namely:
|
|
||||||
|
|
||||||
- `huggingface-hub` to download the checkpoints.
|
|
||||||
- `regex` for the tokenization
|
|
||||||
- `tqdm`, `PIL`, and `numpy` for the scripts
|
|
||||||
- `sentencepiece` for the T5 tokenizer
|
|
||||||
- `datasets` for using an HF dataset directly
|
|
||||||
|
|
||||||
You can install all of the above with the `requirements.txt` as follows:
|
|
||||||
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
|
|
||||||
Usage
|
|
||||||
---------
|
|
||||||
|
|
||||||
You can use the following command to generate an image. Use `--output` to specify where the image is saved; it defaults to `out.png`.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python txt2image.py --model schnell \
|
|
||||||
--n-images 1 \
|
|
||||||
--image-size 256x512 \
|
|
||||||
--verbose \
|
|
||||||
'A photo of an astronaut riding a horse on Mars.'
|
|
||||||
```
|
|
||||||
|
|
||||||
To see the full list of parameters, run the script with `--help`:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python txt2image.py --help
|
|
||||||
```
|
|
||||||
|
|
||||||
Inference
|
|
||||||
---------
|
|
||||||
|
|
||||||
Inference in this example is similar to the stable diffusion example. The
class to get you started is `FluxPipeline` from the `flux` module.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import mlx.core as mx
|
|
||||||
from flux import FluxPipeline
|
|
||||||
|
|
||||||
# This will download all the weights from HF hub
|
|
||||||
flux = FluxPipeline("flux-schnell")
|
|
||||||
|
|
||||||
# Make a generator that returns the latent variables from the reverse diffusion
|
|
||||||
# process
|
|
||||||
latent_generator = flux.generate_latents(
|
|
||||||
"A photo of an astronaut riding a horse on Mars",
|
|
||||||
num_steps=4,
|
|
||||||
latent_size=(32, 64), # 256x512 image
|
|
||||||
)
|
|
||||||
|
|
||||||
# The first return value of the generator contains the conditioning and the
|
|
||||||
# random noise at the beginning of the diffusion process.
|
|
||||||
conditioning = next(latent_generator)
|
|
||||||
(
|
|
||||||
x_T, # The initial noise
|
|
||||||
x_positions, # The integer positions used for image positional encoding
|
|
||||||
t5_conditioning, # The T5 features from the text prompt
|
|
||||||
t5_positions, # Integer positions for text (normally all 0s)
|
|
||||||
clip_conditioning, # The clip text features from the text prompt
|
|
||||||
) = conditioning
|
|
||||||
|
|
||||||
# Returning the conditioning as the first output from the generator allows us
|
|
||||||
# to unload T5 and clip before running the diffusion transformer.
|
|
||||||
mx.eval(conditioning)
|
|
||||||
|
|
||||||
# Evaluate each diffusion step
|
|
||||||
for x_t in latent_generator:
|
|
||||||
mx.eval(x_t)
|
|
||||||
|
|
||||||
# Note that we need to pass the latent size because it is collapsed and
|
|
||||||
# patchified in x_t and we need to unwrap it.
|
|
||||||
img = flux.decode(x_t, latent_size=(32, 64))
|
|
||||||
```
|
|
||||||
|
|
||||||
The above are essentially the implementation of the `txt2image.py` script
|
|
||||||
except for some additional logic to quantize and/or load trained adapters. One
|
|
||||||
can use the script as follows:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python txt2image.py \
|
|
||||||
--n-images 4 \
|
|
||||||
--n-rows 2 \
|
|
||||||
--image-size 256x512 \
|
|
||||||
'A photo of an astronaut riding a horse on Mars.'
|
|
||||||
```
|
|
||||||
|
|
||||||
### Experimental Options
|
|
||||||
|
|
||||||
FLUX pads the prompt to a fixed size of 512 tokens for the dev model and 256
for the schnell model. Not applying padding results in faster generation, but
it is not clear how it may affect the generated images. To enable that option
in this example, pass `--no-t5-padding` to the `txt2image.py` script or
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
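For completeness, here is a minimal sketch of the same option from Python,
mirroring the inference example above; only the `t5_padding=False` argument is
new, everything else is copied from that example.

```python
import mlx.core as mx
from flux import FluxPipeline

# Same pipeline as in the inference example, but skip padding the T5 prompt
# to the fixed 256-token (schnell) / 512-token (dev) length.
flux = FluxPipeline("flux-schnell", t5_padding=False)

latent_generator = flux.generate_latents(
    "A photo of an astronaut riding a horse on Mars",
    num_steps=4,
    latent_size=(32, 64),  # 256x512 image
)

# The first yield is the conditioning; evaluate it before the denoising steps.
conditioning = next(latent_generator)
mx.eval(conditioning)

for x_t in latent_generator:
    mx.eval(x_t)

img = flux.decode(x_t, latent_size=(32, 64))
```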
|
|
||||||
|
|
||||||
Finetuning
|
|
||||||
----------
|
|
||||||
|
|
||||||
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
|
|
||||||
but ymmv) on a provided image dataset. The dataset folder must have an
|
|
||||||
`train.jsonl` file with the following format:
|
|
||||||
|
|
||||||
```jsonl
|
|
||||||
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
|
|
||||||
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
The training script by default trains for 600 iterations with a batch size of
|
|
||||||
1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py
|
|
||||||
--help` for the list of hyperparameters you can tune.
|
|
||||||
|
|
||||||
> [!Note]
|
|
||||||
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
|
|
||||||
> should reduce this number significantly.
|
|
||||||
|
|
||||||
### Training Example
|
|
||||||
|
|
||||||
This is a step-by-step finetuning example. We will be using the data from
|
|
||||||
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
|
|
||||||
In particular, we will use `dog6` which is a popular example for showcasing
|
|
||||||
dreambooth [^1].
|
|
||||||
|
|
||||||
The training images are the following 5 images [^2]:
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
We start by making the following `train.jsonl` file and placing it in the same
|
|
||||||
folder as the images.
|
|
||||||
|
|
||||||
```jsonl
|
|
||||||
{"image": "00.jpg", "prompt": "A photo of sks dog"}
|
|
||||||
{"image": "01.jpg", "prompt": "A photo of sks dog"}
|
|
||||||
{"image": "02.jpg", "prompt": "A photo of sks dog"}
|
|
||||||
{"image": "03.jpg", "prompt": "A photo of sks dog"}
|
|
||||||
{"image": "04.jpg", "prompt": "A photo of sks dog"}
|
|
||||||
```
|
|
||||||
|
|
||||||
Subsequently we finetune FLUX using the following command:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python dreambooth.py \
|
|
||||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
|
||||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
|
||||||
--lora-rank 4 --grad-accumulate 8 \
|
|
||||||
path/to/dreambooth/dataset/dog6
|
|
||||||
```
|
|
||||||
|
|
||||||
Or you can directly use the pre-processed Hugging Face dataset
|
|
||||||
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
|
|
||||||
for fine-tuning.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python dreambooth.py \
|
|
||||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
|
||||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
|
||||||
--lora-rank 4 --grad-accumulate 8 \
|
|
||||||
mlx-community/dreambooth-dog6
|
|
||||||
```
|
|
||||||
|
|
||||||
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
|
|
||||||
bit more than 1 hour.
|
|
||||||
|
|
||||||
### Using the Adapter
|
|
||||||
|
|
||||||
The adapters are saved in `mlx_output` and can be used directly by the
|
|
||||||
`txt2image.py` script. For instance,
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
|
|
||||||
--adapter mlx_output/final_adapters.safetensors \
|
|
||||||
--fuse-adapter \
|
|
||||||
--no-t5-padding \
|
|
||||||
'A photo of an sks dog lying on the sand at a beach in Greece'
|
|
||||||
```
|
|
||||||
|
|
||||||
generates an image that looks like the following,
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
and of course we can pass `--image-size 512x1024` to get larger images with
|
|
||||||
different aspect ratios,
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
The arguments that are relevant to the adapters are of course `--adapter` and
|
|
||||||
`--fuse-adapter`. The first defines the path to an adapter to apply to the
|
|
||||||
model and the second fuses the adapter back into the model to get a bit more
|
|
||||||
speed during generation.
|
|
||||||
|
|
||||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
|
|
||||||
[^2]: The images are from Unsplash by https://unsplash.com/@alvannee.
|
|
||||||
|
|
||||||
|
|
||||||
Distributed Computation
|
|
||||||
------------------------
|
|
||||||
|
|
||||||
The FLUX example supports distributed computation during both generation and
|
|
||||||
training. See the [distributed communication
|
|
||||||
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
|
|
||||||
for information on how to set-up MLX for distributed communication. The rest of
|
|
||||||
this section assumes you can launch distributed MLX programs using `mlx.launch
|
|
||||||
--hostfile hostfile.json`.
|
|
||||||
|
|
||||||
### Distributed Finetuning
|
|
||||||
|
|
||||||
Distributed finetuning scales very well with FLUX. All one has to do is
adjust the gradient accumulation steps and the training iterations so that the
total batch size remains the same. For instance, to replicate the following
training
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python dreambooth.py \
|
|
||||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
|
||||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
|
||||||
--lora-rank 4 --grad-accumulate 8 \
|
|
||||||
mlx-community/dreambooth-dog6
|
|
||||||
```
|
|
||||||
|
|
||||||
On 4 machines we simply run
|
|
||||||
|
|
||||||
```shell
|
|
||||||
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
|
|
||||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
|
||||||
--progress-every 150 --iterations 300 --learning-rate 0.0001 \
|
|
||||||
--lora-rank 4 --grad-accumulate 2 \
|
|
||||||
mlx-community/dreambooth-dog6
|
|
||||||
```
|
|
||||||
|
|
||||||
Note that the iterations changed from 1200 to 300 and the gradient accumulation steps from 8 to 2.
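A quick, purely illustrative sanity check in Python (the numbers are copied
from the two commands above; the variable names are not part of any script)
that the samples per optimizer update and the total samples seen stay the
same:

```python
# batch_size is 1 in both runs; iterations, gradient accumulation, and the
# number of nodes are the only values that change.
single_node = {"grad_accumulate": 8, "iterations": 1200, "nodes": 1}
four_nodes = {"grad_accumulate": 2, "iterations": 300, "nodes": 4}

for name, cfg in (("single node", single_node), ("4 nodes", four_nodes)):
    per_update = cfg["grad_accumulate"] * cfg["nodes"]  # samples per optimizer update
    total_samples = cfg["iterations"] * cfg["nodes"]  # samples seen over training
    print(name, per_update, total_samples)  # both print 8 and 1200
```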
|
|
||||||
|
|
||||||
### Distributed Inference
|
|
||||||
|
|
||||||
Distributed inference can be divided into two different approaches. The first
|
|
||||||
approach is the data-parallel approach, where each node generates its own
|
|
||||||
images and shares them at the end. The second approach is the model-parallel
|
|
||||||
approach where the model is shared across the nodes and they collaboratively
|
|
||||||
generate the images.
|
|
||||||
|
|
||||||
The `txt2image.py` script will attempt to choose the best approach depending on
|
|
||||||
how many images are being generated across the nodes. The model-parallel
|
|
||||||
approach can be forced by passing the argument `--force-shard`.
|
|
||||||
|
|
||||||
For better performance in the model-parallel approach we suggest that you use a
|
|
||||||
[thunderbolt
|
|
||||||
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
|
|
||||||
|
|
||||||
All you have to do once again is use `mlx.launch` as follows
|
|
||||||
|
|
||||||
```shell
|
|
||||||
mlx.launch --verbose --hostfile hostfile.json -- \
|
|
||||||
python txt2image.py --model schnell \
|
|
||||||
--n-images 8 \
|
|
||||||
--image-size 512x512 \
|
|
||||||
--verbose \
|
|
||||||
'A photo of an astronaut riding a horse on Mars'
|
|
||||||
```
|
|
||||||
|
|
||||||
For model-parallel generation you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`. This is an experimental setting that
reduces the CPU/GPU synchronization overhead.
|
|
||||||
@@ -1,292 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
import mlx.optimizers as optim
|
|
||||||
import numpy as np
|
|
||||||
from mlx.nn.utils import average_gradients
|
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from flux import FluxPipeline, Trainer, load_dataset, save_config
|
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
|
||||||
"""Generate images to monitor the progress of the finetuning."""
|
|
||||||
out_dir = Path(args.output_dir)
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
out_file = out_dir / f"{iteration:07d}_progress.png"
|
|
||||||
print(f"Generating {str(out_file)}", flush=True)
|
|
||||||
|
|
||||||
# Generate some images and arrange them in a grid
|
|
||||||
n_rows = 2
|
|
||||||
n_images = 4
|
|
||||||
x = flux.generate_images(
|
|
||||||
args.progress_prompt,
|
|
||||||
n_images,
|
|
||||||
args.progress_steps,
|
|
||||||
)
|
|
||||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
|
||||||
B, H, W, C = x.shape
|
|
||||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
|
||||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
|
||||||
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
|
|
||||||
x = (x * 255).astype(mx.uint8)
|
|
||||||
|
|
||||||
# Save them to disk
|
|
||||||
im = Image.fromarray(np.array(x))
|
|
||||||
im.save(out_file)
|
|
||||||
|
|
||||||
|
|
||||||
def save_adapters(adapter_name, flux, args):
|
|
||||||
out_dir = Path(args.output_dir)
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
out_file = out_dir / adapter_name
|
|
||||||
print(f"Saving {str(out_file)}")
|
|
||||||
|
|
||||||
mx.save_safetensors(
|
|
||||||
str(out_file),
|
|
||||||
dict(tree_flatten(flux.flow.trainable_parameters())),
|
|
||||||
metadata={
|
|
||||||
"lora_rank": str(args.lora_rank),
|
|
||||||
"lora_blocks": str(args.lora_blocks),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
|
||||||
"""Set up and return the argument parser."""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Finetune Flux to generate images with a specific subject"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
default="dev",
|
|
||||||
choices=[
|
|
||||||
"dev",
|
|
||||||
"schnell",
|
|
||||||
],
|
|
||||||
help="Which flux model to train",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--guidance", type=float, default=4.0, help="The guidance factor to use."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--iterations",
|
|
||||||
type=int,
|
|
||||||
default=600,
|
|
||||||
help="How many iterations to train for",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch-size",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The batch size to use when training the stable diffusion model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--resolution",
|
|
||||||
type=lambda x: tuple(map(int, x.split("x"))),
|
|
||||||
default=(512, 512),
|
|
||||||
help="The resolution of the training images",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-augmentations",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Augment the images by random cropping and panning",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--progress-prompt",
|
|
||||||
required=True,
|
|
||||||
help="Use this prompt when generating images for evaluation",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--progress-steps",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Use this many steps when generating images for evaluation",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--progress-every",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Generate images every PROGRESS_EVERY steps",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--checkpoint-every",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Save the model every CHECKPOINT_EVERY steps",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--lora-blocks",
|
|
||||||
type=int,
|
|
||||||
default=-1,
|
|
||||||
help="Train the last LORA_BLOCKS transformer blocks",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--warmup-steps", type=int, default=100, help="Learning rate warmup"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--learning-rate", type=float, default="1e-4", help="Learning rate for training"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--grad-accumulate",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="Accumulate gradients for that many iterations before applying them",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("dataset")
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = setup_arg_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
output_path = Path(args.output_dir)
|
|
||||||
output_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_config(vars(args), output_path / "adapter_config.json")
|
|
||||||
|
|
||||||
# Load the model and set it up for LoRA training. We use the same random
|
|
||||||
# state when creating the LoRA layers so all workers will have the same
|
|
||||||
# initial weights.
|
|
||||||
mx.random.seed(0x0F0F0F0F)
|
|
||||||
flux = FluxPipeline("flux-" + args.model)
|
|
||||||
flux.flow.freeze()
|
|
||||||
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
|
|
||||||
|
|
||||||
# Reset the seed to a different seed per worker if we are in distributed
|
|
||||||
# mode so that each worker is working on different data, diffusion step and
|
|
||||||
# random noise.
|
|
||||||
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
|
|
||||||
|
|
||||||
# Report how many parameters we are training
|
|
||||||
trainable_params = tree_reduce(
|
|
||||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
|
||||||
)
|
|
||||||
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
|
|
||||||
|
|
||||||
# Set up the optimizer and training steps. The steps are a bit verbose to
|
|
||||||
# support gradient accumulation together with compilation.
|
|
||||||
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
|
|
||||||
cosine = optim.cosine_decay(
|
|
||||||
args.learning_rate, args.iterations // args.grad_accumulate
|
|
||||||
)
|
|
||||||
lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
|
|
||||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
|
||||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
|
||||||
def single_step(x, t5_feat, clip_feat, guidance):
|
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
|
||||||
x, t5_feat, clip_feat, guidance
|
|
||||||
)
|
|
||||||
grads = average_gradients(grads)
|
|
||||||
optimizer.update(flux.flow, grads)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
|
||||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
|
||||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
|
||||||
x, t5_feat, clip_feat, guidance
|
|
||||||
)
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
|
||||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
|
||||||
x, t5_feat, clip_feat, guidance
|
|
||||||
)
|
|
||||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
|
||||||
return loss, grads
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
|
||||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
|
||||||
x, t5_feat, clip_feat, guidance
|
|
||||||
)
|
|
||||||
grads = tree_map(
|
|
||||||
lambda a, b: (a + b) / args.grad_accumulate,
|
|
||||||
prev_grads,
|
|
||||||
grads,
|
|
||||||
)
|
|
||||||
grads = average_gradients(grads)
|
|
||||||
optimizer.update(flux.flow, grads)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
# We simply route to the appropriate step based on whether we have
|
|
||||||
# gradients from a previous step and whether we should be performing an
|
|
||||||
# update or simply computing and accumulating gradients in this step.
|
|
||||||
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
|
|
||||||
if prev_grads is None:
|
|
||||||
if perform_step:
|
|
||||||
return single_step(x, t5_feat, clip_feat, guidance), None
|
|
||||||
else:
|
|
||||||
return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
|
|
||||||
else:
|
|
||||||
if perform_step:
|
|
||||||
return (
|
|
||||||
grad_accumulate_and_step(
|
|
||||||
x, t5_feat, clip_feat, guidance, prev_grads
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return compute_loss_and_accumulate_grads(
|
|
||||||
x, t5_feat, clip_feat, guidance, prev_grads
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = load_dataset(args.dataset)
|
|
||||||
trainer = Trainer(flux, dataset, args)
|
|
||||||
trainer.encode_dataset()
|
|
||||||
|
|
||||||
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
|
||||||
|
|
||||||
# An initial generation to compare
|
|
||||||
generate_progress_images(0, flux, args)
|
|
||||||
|
|
||||||
grads = None
|
|
||||||
losses = []
|
|
||||||
tic = time.time()
|
|
||||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
|
||||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
|
||||||
mx.eval(loss, grads, state)
|
|
||||||
losses.append(loss.item())
|
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
|
||||||
toc = time.time()
|
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
|
||||||
print(
|
|
||||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
|
||||||
f"It/s: {10 / (toc - tic):.3f} "
|
|
||||||
f"Peak mem: {peak_mem:.3f} GB",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (i + 1) % args.progress_every == 0:
|
|
||||||
generate_progress_images(i + 1, flux, args)
|
|
||||||
|
|
||||||
if (i + 1) % args.checkpoint_every == 0:
|
|
||||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
|
||||||
losses = []
|
|
||||||
tic = time.time()
|
|
||||||
|
|
||||||
save_adapters("final_adapters.safetensors", flux, args)
|
|
||||||
print("Training successful.")
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
from .datasets import Dataset, load_dataset
|
|
||||||
from .flux import FluxPipeline
|
|
||||||
from .lora import LoRALinear
|
|
||||||
from .sampler import FluxSampler
|
|
||||||
from .trainer import Trainer
|
|
||||||
from .utils import (
|
|
||||||
load_ae,
|
|
||||||
load_clip,
|
|
||||||
load_clip_tokenizer,
|
|
||||||
load_flow_model,
|
|
||||||
load_t5,
|
|
||||||
load_t5_tokenizer,
|
|
||||||
save_config,
|
|
||||||
)
|
|
||||||
@@ -1,357 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
from mlx.nn.layers.upsample import upsample_nearest
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AutoEncoderParams:
|
|
||||||
resolution: int
|
|
||||||
in_channels: int
|
|
||||||
ch: int
|
|
||||||
out_ch: int
|
|
||||||
ch_mult: List[int]
|
|
||||||
num_res_blocks: int
|
|
||||||
z_channels: int
|
|
||||||
scale_factor: float
|
|
||||||
shift_factor: float
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels: int):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = nn.GroupNorm(
|
|
||||||
num_groups=32,
|
|
||||||
dims=in_channels,
|
|
||||||
eps=1e-6,
|
|
||||||
affine=True,
|
|
||||||
pytorch_compatible=True,
|
|
||||||
)
|
|
||||||
self.q = nn.Linear(in_channels, in_channels)
|
|
||||||
self.k = nn.Linear(in_channels, in_channels)
|
|
||||||
self.v = nn.Linear(in_channels, in_channels)
|
|
||||||
self.proj_out = nn.Linear(in_channels, in_channels)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
|
||||||
B, H, W, C = x.shape
|
|
||||||
|
|
||||||
y = x.reshape(B, 1, -1, C)
|
|
||||||
y = self.norm(y)
|
|
||||||
q = self.q(y)
|
|
||||||
k = self.k(y)
|
|
||||||
v = self.v(y)
|
|
||||||
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
|
|
||||||
y = self.proj_out(y)
|
|
||||||
|
|
||||||
return x + y.reshape(B, H, W, C)
|
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels: int, out_channels: int):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
self.norm1 = nn.GroupNorm(
|
|
||||||
num_groups=32,
|
|
||||||
dims=in_channels,
|
|
||||||
eps=1e-6,
|
|
||||||
affine=True,
|
|
||||||
pytorch_compatible=True,
|
|
||||||
)
|
|
||||||
self.conv1 = nn.Conv2d(
|
|
||||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
self.norm2 = nn.GroupNorm(
|
|
||||||
num_groups=32,
|
|
||||||
dims=out_channels,
|
|
||||||
eps=1e-6,
|
|
||||||
affine=True,
|
|
||||||
pytorch_compatible=True,
|
|
||||||
)
|
|
||||||
self.conv2 = nn.Conv2d(
|
|
||||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
self.nin_shortcut = nn.Linear(in_channels, out_channels)
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
h = x
|
|
||||||
h = self.norm1(h)
|
|
||||||
h = nn.silu(h)
|
|
||||||
h = self.conv1(h)
|
|
||||||
|
|
||||||
h = self.norm2(h)
|
|
||||||
h = nn.silu(h)
|
|
||||||
h = self.conv2(h)
|
|
||||||
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
x = self.nin_shortcut(x)
|
|
||||||
|
|
||||||
return x + h
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels: int):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
|
||||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels: int):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
|
||||||
x = upsample_nearest(x, (2, 2))
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
resolution: int,
|
|
||||||
in_channels: int,
|
|
||||||
ch: int,
|
|
||||||
ch_mult: list[int],
|
|
||||||
num_res_blocks: int,
|
|
||||||
z_channels: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.ch = ch
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.in_channels = in_channels
|
|
||||||
# downsampling
|
|
||||||
self.conv_in = nn.Conv2d(
|
|
||||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
curr_res = resolution
|
|
||||||
in_ch_mult = (1,) + tuple(ch_mult)
|
|
||||||
self.in_ch_mult = in_ch_mult
|
|
||||||
self.down = []
|
|
||||||
block_in = self.ch
|
|
||||||
for i_level in range(self.num_resolutions):
|
|
||||||
block = []
|
|
||||||
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
|
||||||
block_in = ch * in_ch_mult[i_level]
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
|
||||||
block_in = block_out
|
|
||||||
down = {}
|
|
||||||
down["block"] = block
|
|
||||||
down["attn"] = attn
|
|
||||||
if i_level != self.num_resolutions - 1:
|
|
||||||
down["downsample"] = Downsample(block_in)
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
self.down.append(down)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
self.mid = {}
|
|
||||||
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
|
||||||
self.mid["attn_1"] = AttnBlock(block_in)
|
|
||||||
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = nn.GroupNorm(
|
|
||||||
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
|
||||||
)
|
|
||||||
self.conv_out = nn.Conv2d(
|
|
||||||
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
|
||||||
hs = [self.conv_in(x)]
|
|
||||||
for i_level in range(self.num_resolutions):
|
|
||||||
for i_block in range(self.num_res_blocks):
|
|
||||||
h = self.down[i_level]["block"][i_block](hs[-1])
|
|
||||||
|
|
||||||
# TODO: Remove the attn
|
|
||||||
if len(self.down[i_level]["attn"]) > 0:
|
|
||||||
h = self.down[i_level]["attn"][i_block](h)
|
|
||||||
|
|
||||||
hs.append(h)
|
|
||||||
|
|
||||||
if i_level != self.num_resolutions - 1:
|
|
||||||
hs.append(self.down[i_level]["downsample"](hs[-1]))
|
|
||||||
|
|
||||||
# middle
|
|
||||||
h = hs[-1]
|
|
||||||
h = self.mid["block_1"](h)
|
|
||||||
h = self.mid["attn_1"](h)
|
|
||||||
h = self.mid["block_2"](h)
|
|
||||||
|
|
||||||
# end
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nn.silu(h)
|
|
||||||
h = self.conv_out(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
ch: int,
|
|
||||||
out_ch: int,
|
|
||||||
ch_mult: list[int],
|
|
||||||
num_res_blocks: int,
|
|
||||||
in_channels: int,
|
|
||||||
resolution: int,
|
|
||||||
z_channels: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.ch = ch
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
|
||||||
|
|
||||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
|
||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
|
||||||
|
|
||||||
# z to block_in
|
|
||||||
self.conv_in = nn.Conv2d(
|
|
||||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
self.mid = {}
|
|
||||||
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
|
||||||
self.mid["attn_1"] = AttnBlock(block_in)
|
|
||||||
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
self.up = []
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
block = []
|
|
||||||
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
|
||||||
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for _ in range(self.num_res_blocks + 1):
|
|
||||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
|
||||||
block_in = block_out
|
|
||||||
up = {}
|
|
||||||
up["block"] = block
|
|
||||||
up["attn"] = attn
|
|
||||||
if i_level != 0:
|
|
||||||
up["upsample"] = Upsample(block_in)
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
self.up.insert(0, up) # prepend to get consistent order
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = nn.GroupNorm(
|
|
||||||
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
|
||||||
)
|
|
||||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
|
||||||
|
|
||||||
def __call__(self, z: mx.array):
|
|
||||||
# z to block_in
|
|
||||||
h = self.conv_in(z)
|
|
||||||
|
|
||||||
# middle
|
|
||||||
h = self.mid["block_1"](h)
|
|
||||||
h = self.mid["attn_1"](h)
|
|
||||||
h = self.mid["block_2"](h)
|
|
||||||
|
|
||||||
# upsampling
|
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
h = self.up[i_level]["block"][i_block](h)
|
|
||||||
|
|
||||||
# TODO: Remove the attn
|
|
||||||
if len(self.up[i_level]["attn"]) > 0:
|
|
||||||
h = self.up[i_level]["attn"][i_block](h)
|
|
||||||
|
|
||||||
if i_level != 0:
|
|
||||||
h = self.up[i_level]["upsample"](h)
|
|
||||||
|
|
||||||
# end
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nn.silu(h)
|
|
||||||
h = self.conv_out(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class DiagonalGaussian(nn.Module):
|
|
||||||
def __call__(self, z: mx.array):
|
|
||||||
mean, logvar = mx.split(z, 2, axis=-1)
|
|
||||||
if self.training:
|
|
||||||
std = mx.exp(0.5 * logvar)
|
|
||||||
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
|
|
||||||
return mean + std * eps
|
|
||||||
else:
|
|
||||||
return mean
|
|
||||||
|
|
||||||
|
|
||||||
class AutoEncoder(nn.Module):
|
|
||||||
def __init__(self, params: AutoEncoderParams):
|
|
||||||
super().__init__()
|
|
||||||
self.encoder = Encoder(
|
|
||||||
resolution=params.resolution,
|
|
||||||
in_channels=params.in_channels,
|
|
||||||
ch=params.ch,
|
|
||||||
ch_mult=params.ch_mult,
|
|
||||||
num_res_blocks=params.num_res_blocks,
|
|
||||||
z_channels=params.z_channels,
|
|
||||||
)
|
|
||||||
self.decoder = Decoder(
|
|
||||||
resolution=params.resolution,
|
|
||||||
in_channels=params.in_channels,
|
|
||||||
ch=params.ch,
|
|
||||||
out_ch=params.out_ch,
|
|
||||||
ch_mult=params.ch_mult,
|
|
||||||
num_res_blocks=params.num_res_blocks,
|
|
||||||
z_channels=params.z_channels,
|
|
||||||
)
|
|
||||||
self.reg = DiagonalGaussian()
|
|
||||||
|
|
||||||
self.scale_factor = params.scale_factor
|
|
||||||
self.shift_factor = params.shift_factor
|
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
new_weights = {}
|
|
||||||
for k, w in weights.items():
|
|
||||||
if w.ndim == 4:
|
|
||||||
w = w.transpose(0, 2, 3, 1)
|
|
||||||
w = w.reshape(-1).reshape(w.shape)
|
|
||||||
if w.shape[1:3] == (1, 1):
|
|
||||||
w = w.squeeze((1, 2))
|
|
||||||
new_weights[k] = w
|
|
||||||
return new_weights
|
|
||||||
|
|
||||||
def encode(self, x: mx.array):
|
|
||||||
z = self.reg(self.encoder(x))
|
|
||||||
z = self.scale_factor * (z - self.shift_factor)
|
|
||||||
return z
|
|
||||||
|
|
||||||
def decode(self, z: mx.array):
|
|
||||||
z = z / self.scale_factor + self.shift_factor
|
|
||||||
return self.decoder(z)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
|
||||||
return self.decode(self.encode(x))
|
|
||||||
@@ -1,154 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CLIPTextModelConfig:
|
|
||||||
num_layers: int = 23
|
|
||||||
model_dims: int = 1024
|
|
||||||
num_heads: int = 16
|
|
||||||
max_length: int = 77
|
|
||||||
vocab_size: int = 49408
|
|
||||||
hidden_act: str = "quick_gelu"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, config):
|
|
||||||
return cls(
|
|
||||||
num_layers=config["num_hidden_layers"],
|
|
||||||
model_dims=config["hidden_size"],
|
|
||||||
num_heads=config["num_attention_heads"],
|
|
||||||
max_length=config["max_position_embeddings"],
|
|
||||||
vocab_size=config["vocab_size"],
|
|
||||||
hidden_act=config["hidden_act"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CLIPOutput:
|
|
||||||
# The last_hidden_state indexed at the EOS token and possibly projected if
|
|
||||||
# the model has a projection layer
|
|
||||||
pooled_output: Optional[mx.array] = None
|
|
||||||
|
|
||||||
# The full sequence output of the transformer after the final layernorm
|
|
||||||
last_hidden_state: Optional[mx.array] = None
|
|
||||||
|
|
||||||
# A list of hidden states corresponding to the outputs of the transformer layers
|
|
||||||
hidden_states: Optional[List[mx.array]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPEncoderLayer(nn.Module):
|
|
||||||
"""The transformer encoder layer from CLIP."""
|
|
||||||
|
|
||||||
def __init__(self, model_dims: int, num_heads: int, activation: str):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
|
||||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
|
||||||
|
|
||||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
|
|
||||||
|
|
||||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
|
||||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
|
||||||
|
|
||||||
self.act = _ACTIVATIONS[activation]
|
|
||||||
|
|
||||||
def __call__(self, x, attn_mask=None):
|
|
||||||
y = self.layer_norm1(x)
|
|
||||||
y = self.attention(y, y, y, attn_mask)
|
|
||||||
x = y + x
|
|
||||||
|
|
||||||
y = self.layer_norm2(x)
|
|
||||||
y = self.linear1(y)
|
|
||||||
y = self.act(y)
|
|
||||||
y = self.linear2(y)
|
|
||||||
x = y + x
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextModel(nn.Module):
|
|
||||||
"""Implements the text encoder transformer from CLIP."""
|
|
||||||
|
|
||||||
def __init__(self, config: CLIPTextModelConfig):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
|
||||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
|
||||||
self.layers = [
|
|
||||||
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
|
|
||||||
for i in range(config.num_layers)
|
|
||||||
]
|
|
||||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
|
||||||
|
|
||||||
def _get_mask(self, N, dtype):
|
|
||||||
indices = mx.arange(N)
|
|
||||||
mask = indices[:, None] < indices[None]
|
|
||||||
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
new_weights = {}
|
|
||||||
for key, w in weights.items():
|
|
||||||
# Remove prefixes
|
|
||||||
if key.startswith("text_model."):
|
|
||||||
key = key[11:]
|
|
||||||
if key.startswith("embeddings."):
|
|
||||||
key = key[11:]
|
|
||||||
if key.startswith("encoder."):
|
|
||||||
key = key[8:]
|
|
||||||
|
|
||||||
# Map attention layers
|
|
||||||
if "self_attn." in key:
|
|
||||||
key = key.replace("self_attn.", "attention.")
|
|
||||||
if "q_proj." in key:
|
|
||||||
key = key.replace("q_proj.", "query_proj.")
|
|
||||||
if "k_proj." in key:
|
|
||||||
key = key.replace("k_proj.", "key_proj.")
|
|
||||||
if "v_proj." in key:
|
|
||||||
key = key.replace("v_proj.", "value_proj.")
|
|
||||||
|
|
||||||
# Map ffn layers
|
|
||||||
if "mlp.fc1" in key:
|
|
||||||
key = key.replace("mlp.fc1", "linear1")
|
|
||||||
if "mlp.fc2" in key:
|
|
||||||
key = key.replace("mlp.fc2", "linear2")
|
|
||||||
|
|
||||||
new_weights[key] = w
|
|
||||||
|
|
||||||
return new_weights
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
# Extract some shapes
|
|
||||||
B, N = x.shape
|
|
||||||
eos_tokens = x.argmax(-1)
|
|
||||||
|
|
||||||
# Compute the embeddings
|
|
||||||
x = self.token_embedding(x)
|
|
||||||
x = x + self.position_embedding.weight[:N]
|
|
||||||
|
|
||||||
# Compute the features from the transformer
|
|
||||||
mask = self._get_mask(N, x.dtype)
|
|
||||||
hidden_states = []
|
|
||||||
for l in self.layers:
|
|
||||||
x = l(x, mask)
|
|
||||||
hidden_states.append(x)
|
|
||||||
|
|
||||||
# Apply the final layernorm and return
|
|
||||||
x = self.final_layer_norm(x)
|
|
||||||
last_hidden_state = x
|
|
||||||
|
|
||||||
# Select the EOS token
|
|
||||||
pooled_output = x[mx.arange(len(x)), eos_tokens]
|
|
||||||
|
|
||||||
return CLIPOutput(
|
|
||||||
pooled_output=pooled_output,
|
|
||||||
last_hidden_state=last_hidden_state,
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
)
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
|
||||||
def __getitem__(self, index: int):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class LocalDataset(Dataset):
|
|
||||||
prompt_key = "prompt"
|
|
||||||
|
|
||||||
def __init__(self, dataset: str, data_file):
|
|
||||||
self.dataset_base = Path(dataset)
|
|
||||||
with open(data_file, "r") as fid:
|
|
||||||
self._data = [json.loads(l) for l in fid]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._data)
|
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
|
||||||
item = self._data[index]
|
|
||||||
image = Image.open(self.dataset_base / item["image"])
|
|
||||||
return image, item[self.prompt_key]
|
|
||||||
|
|
||||||
|
|
||||||
class LegacyDataset(LocalDataset):
|
|
||||||
prompt_key = "text"
|
|
||||||
|
|
||||||
def __init__(self, dataset: str):
|
|
||||||
self.dataset_base = Path(dataset)
|
|
||||||
with open(self.dataset_base / "index.json") as f:
|
|
||||||
self._data = json.load(f)["data"]
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(self, dataset: str):
|
|
||||||
from datasets import load_dataset as hf_load_dataset
|
|
||||||
|
|
||||||
self._df = hf_load_dataset(dataset)["train"]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._df)
|
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
|
||||||
item = self._df[index]
|
|
||||||
return item["image"], item["prompt"]
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(dataset: str):
|
|
||||||
dataset_base = Path(dataset)
|
|
||||||
data_file = dataset_base / "train.jsonl"
|
|
||||||
legacy_file = dataset_base / "index.json"
|
|
||||||
|
|
||||||
if data_file.exists():
|
|
||||||
print(f"Load the local dataset {data_file} .", flush=True)
|
|
||||||
dataset = LocalDataset(dataset, data_file)
|
|
||||||
elif legacy_file.exists():
|
|
||||||
print(f"Load the local dataset {legacy_file} .")
|
|
||||||
print()
|
|
||||||
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
|
|
||||||
print(" See the README for details.")
|
|
||||||
print(flush=True)
|
|
||||||
dataset = LegacyDataset(dataset)
|
|
||||||
else:
|
|
||||||
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
|
||||||
dataset = HuggingFaceDataset(dataset)
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
@@ -1,246 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
from mlx.utils import tree_unflatten
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .lora import LoRALinear
|
|
||||||
from .sampler import FluxSampler
|
|
||||||
from .utils import (
|
|
||||||
load_ae,
|
|
||||||
load_clip,
|
|
||||||
load_clip_tokenizer,
|
|
||||||
load_flow_model,
|
|
||||||
load_t5,
|
|
||||||
load_t5_tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FluxPipeline:
|
|
||||||
def __init__(self, name: str, t5_padding: bool = True):
|
|
||||||
self.dtype = mx.bfloat16
|
|
||||||
self.name = name
|
|
||||||
self.t5_padding = t5_padding
|
|
||||||
|
|
||||||
self.ae = load_ae(name)
|
|
||||||
self.flow = load_flow_model(name)
|
|
||||||
self.clip = load_clip(name)
|
|
||||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
|
||||||
self.t5 = load_t5(name)
|
|
||||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
|
||||||
self.sampler = FluxSampler(name)
|
|
||||||
|
|
||||||
def ensure_models_are_loaded(self):
|
|
||||||
mx.eval(
|
|
||||||
self.ae.parameters(),
|
|
||||||
self.flow.parameters(),
|
|
||||||
self.clip.parameters(),
|
|
||||||
self.t5.parameters(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def reload_text_encoders(self):
|
|
||||||
self.t5 = load_t5(self.name)
|
|
||||||
self.clip = load_clip(self.name)
|
|
||||||
|
|
||||||
def tokenize(self, text):
|
|
||||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
|
||||||
clip_tokens = self.clip_tokenizer.encode(text)
|
|
||||||
return t5_tokens, clip_tokens
|
|
||||||
|
|
||||||
def _prepare_latent_images(self, x):
|
|
||||||
b, h, w, c = x.shape
|
|
||||||
|
|
||||||
# Pack the latent image to 2x2 patches
|
|
||||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
|
||||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
|
||||||
|
|
||||||
# Create positions ids used to positionally encode each patch. Due to
|
|
||||||
# the way RoPE works, this results in an interesting positional
|
|
||||||
# encoding where parts of the feature are holding different positional
|
|
||||||
# information. Namely, the first part holds information independent of
|
|
||||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
|
||||||
# information and the last one horizontal.
|
|
||||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
|
||||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
|
||||||
x_ids = mx.stack([i, j, k], axis=-1)
|
|
||||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
|
||||||
|
|
||||||
return x, x_ids
|
|
||||||
|
|
||||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
|
||||||
# Prepare the text features
|
|
||||||
txt = self.t5(t5_tokens)
|
|
||||||
if len(txt) == 1 and n_images > 1:
|
|
||||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
|
||||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
|
||||||
|
|
||||||
# Prepare the clip text features
|
|
||||||
vec = self.clip(clip_tokens).pooled_output
|
|
||||||
if len(vec) == 1 and n_images > 1:
|
|
||||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
|
||||||
|
|
||||||
return txt, txt_ids, vec
|
|
||||||
|
|
||||||
def _denoising_loop(
|
|
||||||
self,
|
|
||||||
x_t,
|
|
||||||
x_ids,
|
|
||||||
txt,
|
|
||||||
txt_ids,
|
|
||||||
vec,
|
|
||||||
num_steps: int = 35,
|
|
||||||
guidance: float = 4.0,
|
|
||||||
start: float = 1,
|
|
||||||
stop: float = 0,
|
|
||||||
):
|
|
||||||
B = len(x_t)
|
|
||||||
|
|
||||||
def scalar(x):
|
|
||||||
return mx.full((B,), x, dtype=self.dtype)
|
|
||||||
|
|
||||||
guidance = scalar(guidance)
|
|
||||||
timesteps = self.sampler.timesteps(
|
|
||||||
num_steps,
|
|
||||||
x_t.shape[1],
|
|
||||||
start=start,
|
|
||||||
stop=stop,
|
|
||||||
)
|
|
||||||
for i in range(num_steps):
|
|
||||||
t = timesteps[i]
|
|
||||||
t_prev = timesteps[i + 1]
|
|
||||||
|
|
||||||
pred = self.flow(
|
|
||||||
img=x_t,
|
|
||||||
img_ids=x_ids,
|
|
||||||
txt=txt,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
y=vec,
|
|
||||||
timesteps=scalar(t),
|
|
||||||
guidance=guidance,
|
|
||||||
)
|
|
||||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
|
||||||
|
|
||||||
yield x_t
|
|
||||||
|
|
||||||
def generate_latents(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
n_images: int = 1,
|
|
||||||
num_steps: int = 35,
|
|
||||||
guidance: float = 4.0,
|
|
||||||
latent_size: Tuple[int, int] = (64, 64),
|
|
||||||
seed=None,
|
|
||||||
):
|
|
||||||
# Set the PRNG state
|
|
||||||
if seed is not None:
|
|
||||||
mx.random.seed(seed)
|
|
||||||
|
|
||||||
# Create the latent variables
|
|
||||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
|
||||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
|
||||||
|
|
||||||
# Get the conditioning
|
|
||||||
t5_tokens, clip_tokens = self.tokenize(text)
|
|
||||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
|
||||||
|
|
||||||
# Yield the conditioning for controlled evaluation by the caller
|
|
||||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
|
||||||
|
|
||||||
# Yield the latent sequences from the denoising loop
|
|
||||||
yield from self._denoising_loop(
|
|
||||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
|
||||||
)
|
|
||||||
|
|
||||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
|
||||||
h, w = latent_size
|
|
||||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
|
||||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
|
||||||
x = self.ae.decode(x)
|
|
||||||
return mx.clip(x + 1, 0, 2) * 0.5
|
|
||||||
|
|
||||||
def generate_images(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
n_images: int = 1,
|
|
||||||
num_steps: int = 35,
|
|
||||||
guidance: float = 4.0,
|
|
||||||
latent_size: Tuple[int, int] = (64, 64),
|
|
||||||
seed=None,
|
|
||||||
reload_text_encoders: bool = True,
|
|
||||||
progress: bool = True,
|
|
||||||
):
|
|
||||||
latents = self.generate_latents(
|
|
||||||
text, n_images, num_steps, guidance, latent_size, seed
|
|
||||||
)
|
|
||||||
mx.eval(next(latents))
|
|
||||||
|
|
||||||
if reload_text_encoders:
|
|
||||||
self.reload_text_encoders()
|
|
||||||
|
|
||||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
|
||||||
mx.eval(x_t)
|
|
||||||
|
|
||||||
images = []
|
|
||||||
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
|
|
||||||
images.append(self.decode(x_t[i : i + 1]))
|
|
||||||
mx.eval(images[-1])
|
|
||||||
images = mx.concatenate(images, axis=0)
|
|
||||||
mx.eval(images)
|
|
||||||
|
|
||||||
return images
|
|
||||||
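Editor's aside (not part of the original file): a hypothetical usage sketch of the pipeline above; the model name and argument values are illustrative, not prescriptive.

flux = FluxPipeline("flux-schnell")
images = flux.generate_images(
    "a photo of an astronaut riding a horse on mars",
    n_images=1,
    num_steps=4,           # schnell is distilled for very few steps
    latent_size=(64, 64),  # decoded to 512x512 pixels
    seed=0,
)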
|
|
||||||
def training_loss(
|
|
||||||
self,
|
|
||||||
x_0: mx.array,
|
|
||||||
t5_features: mx.array,
|
|
||||||
clip_features: mx.array,
|
|
||||||
guidance: mx.array,
|
|
||||||
):
|
|
||||||
# Get the text conditioning
|
|
||||||
txt = t5_features
|
|
||||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
|
||||||
vec = clip_features
|
|
||||||
|
|
||||||
# Prepare the latent input
|
|
||||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
|
||||||
|
|
||||||
# Forward process
|
|
||||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
|
||||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
|
||||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
|
||||||
x_t = mx.stop_gradient(x_t)
|
|
||||||
|
|
||||||
# Do the denoising
|
|
||||||
pred = self.flow(
|
|
||||||
img=x_t,
|
|
||||||
img_ids=x_ids,
|
|
||||||
txt=txt,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
y=vec,
|
|
||||||
timesteps=t,
|
|
||||||
guidance=guidance,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (pred + x_0 - eps).square().mean()
|
|
||||||
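Editor's note (not part of the original file): FluxSampler.add_noise (further below) forms x_t = (1 - t) * x_0 + t * eps, so the regression target for the flow is the velocity eps - x_0 and the expression above equals mean((pred - (eps - x_0)) ** 2).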
|
|
||||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
|
||||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
|
||||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
|
||||||
all_blocks.reverse()
|
|
||||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
|
||||||
for i, block in zip(range(num_blocks), all_blocks):
|
|
||||||
loras = []
|
|
||||||
for name, module in block.named_modules():
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
|
||||||
block.update_modules(tree_unflatten(loras))
|
|
||||||
|
|
||||||
def fuse_lora_layers(self):
|
|
||||||
fused_layers = []
|
|
||||||
for name, module in self.flow.named_modules():
|
|
||||||
if isinstance(module, LoRALinear):
|
|
||||||
fused_layers.append((name, module.fuse()))
|
|
||||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
|
||||||
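Editor's aside (not part of the original file): a hypothetical sequence for LoRA fine-tuning with the two helpers above; the rank and block count are illustrative.

flux.linear_to_lora_layers(rank=8, num_blocks=8)  # adapt only the last 8 blocks
# ... fine-tune the LoRA parameters of flux.flow ...
flux.fuse_lora_layers()                           # bake the adapters back into the base weights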
@@ -1,321 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import partial
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def _rope(pos: mx.array, dim: int, theta: float):
|
|
||||||
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
|
||||||
omega = 1.0 / (theta**scale)
|
|
||||||
x = pos[..., None] * omega
|
|
||||||
cosx = mx.cos(x)
|
|
||||||
sinx = mx.sin(x)
|
|
||||||
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
|
|
||||||
pe = pe.reshape(*pe.shape[:-1], 2, 2)
|
|
||||||
|
|
||||||
return pe
|
|
||||||
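Editor's note (not part of the original file): for a position p and per-pair frequency omega_i = theta ** (-2 * i / dim), each trailing 2x2 block of pe is the rotation matrix [[cos(p * omega_i), -sin(p * omega_i)], [sin(p * omega_i), cos(p * omega_i)]], which _apply_rope below multiplies against consecutive feature pairs of the queries and keys.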
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
|
||||||
def _ab_plus_cd(a, b, c, d):
|
|
||||||
return a * b + c * d
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_rope(x, pe):
|
|
||||||
s = x.shape
|
|
||||||
x = x.reshape(*s[:-1], -1, 1, 2)
|
|
||||||
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
|
|
||||||
return x.reshape(s)
|
|
||||||
|
|
||||||
|
|
||||||
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
|
|
||||||
B, H, L, D = q.shape
|
|
||||||
|
|
||||||
q = _apply_rope(q, pe)
|
|
||||||
k = _apply_rope(k, pe)
|
|
||||||
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
|
|
||||||
|
|
||||||
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
|
|
||||||
|
|
||||||
def timestep_embedding(
|
|
||||||
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
|
|
||||||
):
|
|
||||||
half = dim // 2
|
|
||||||
freqs = mx.arange(0, half, dtype=mx.float32) / half
|
|
||||||
freqs = freqs * (-math.log(max_period))
|
|
||||||
freqs = mx.exp(freqs)
|
|
||||||
|
|
||||||
x = (time_factor * t)[:, None] * freqs[None]
|
|
||||||
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
|
|
||||||
|
|
||||||
return x.astype(t.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
|
||||||
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.theta = theta
|
|
||||||
self.axes_dim = axes_dim
|
|
||||||
|
|
||||||
def __call__(self, ids: mx.array):
|
|
||||||
n_axes = ids.shape[-1]
|
|
||||||
pe = mx.concatenate(
|
|
||||||
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
|
||||||
axis=-3,
|
|
||||||
)
|
|
||||||
|
|
||||||
return pe[:, None]
|
|
||||||
|
|
||||||
|
|
||||||
class MLPEmbedder(nn.Module):
|
|
||||||
def __init__(self, in_dim: int, hidden_dim: int):
|
|
||||||
super().__init__()
|
|
||||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
|
||||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
|
||||||
return self.out_layer(nn.silu(self.in_layer(x)))
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(nn.Module):
|
|
||||||
def __init__(self, dim: int):
|
|
||||||
super().__init__()
|
|
||||||
self.query_norm = nn.RMSNorm(dim)
|
|
||||||
self.key_norm = nn.RMSNorm(dim)
|
|
||||||
|
|
||||||
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
|
|
||||||
return self.query_norm(q), self.key_norm(k)
|
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
|
||||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
head_dim = dim // num_heads
|
|
||||||
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
||||||
self.norm = QKNorm(head_dim)
|
|
||||||
self.proj = nn.Linear(dim, dim)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
|
|
||||||
H = self.num_heads
|
|
||||||
B, L, _ = x.shape
|
|
||||||
qkv = self.qkv(x)
|
|
||||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
|
||||||
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
q, k = self.norm(q, k)
|
|
||||||
x = _attention(q, k, v, pe)
|
|
||||||
x = self.proj(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModulationOut:
|
|
||||||
shift: mx.array
|
|
||||||
scale: mx.array
|
|
||||||
gate: mx.array
|
|
||||||
|
|
||||||
|
|
||||||
class Modulation(nn.Module):
|
|
||||||
def __init__(self, dim: int, double: bool):
|
|
||||||
super().__init__()
|
|
||||||
self.is_double = double
|
|
||||||
self.multiplier = 6 if double else 3
|
|
||||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
|
||||||
x = self.lin(nn.silu(x))
|
|
||||||
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
|
|
||||||
|
|
||||||
mod1 = ModulationOut(*xs[:3])
|
|
||||||
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
|
||||||
|
|
||||||
return mod1, mod2
|
|
||||||
|
|
||||||
|
|
||||||
class DoubleStreamBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.img_mod = Modulation(hidden_size, double=True)
|
|
||||||
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
|
||||||
self.img_attn = SelfAttention(
|
|
||||||
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
|
||||||
self.img_mlp = nn.Sequential(
|
|
||||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
|
||||||
nn.GELU(approx="tanh"),
|
|
||||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.txt_mod = Modulation(hidden_size, double=True)
|
|
||||||
self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
|
||||||
self.txt_attn = SelfAttention(
|
|
||||||
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
|
||||||
self.txt_mlp = nn.Sequential(
|
|
||||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
|
||||||
nn.GELU(approx="tanh"),
|
|
||||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.sharding_group = None
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
|
|
||||||
) -> Tuple[mx.array, mx.array]:
|
|
||||||
B, L, _ = img.shape
|
|
||||||
_, S, _ = txt.shape
|
|
||||||
H = self.num_heads
|
|
||||||
|
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
|
||||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
|
||||||
|
|
||||||
# prepare image for attention
|
|
||||||
img_modulated = self.img_norm1(img)
|
|
||||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
|
||||||
img_qkv = self.img_attn.qkv(img_modulated)
|
|
||||||
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
|
|
||||||
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
img_q, img_k = self.img_attn.norm(img_q, img_k)
|
|
||||||
|
|
||||||
# prepare txt for attention
|
|
||||||
txt_modulated = self.txt_norm1(txt)
|
|
||||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
|
||||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
|
||||||
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
|
|
||||||
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
|
|
||||||
|
|
||||||
# run actual attention
|
|
||||||
q = mx.concatenate([txt_q, img_q], axis=2)
|
|
||||||
k = mx.concatenate([txt_k, img_k], axis=2)
|
|
||||||
v = mx.concatenate([txt_v, img_v], axis=2)
|
|
||||||
|
|
||||||
attn = _attention(q, k, v, pe)
|
|
||||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
|
||||||
|
|
||||||
# Project; when sharded: concatenate, all-reduce across ranks, split
|
|
||||||
txt_attn = self.txt_attn.proj(txt_attn)
|
|
||||||
img_attn = self.img_attn.proj(img_attn)
|
|
||||||
if self.sharding_group is not None:
|
|
||||||
attn = mx.concatenate([txt_attn, img_attn], axis=1)
|
|
||||||
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
|
|
||||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
|
||||||
|
|
||||||
# calculate the img blocks
|
|
||||||
img = img + img_mod1.gate * img_attn
|
|
||||||
img_mlp = self.img_mlp(
|
|
||||||
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate the txt blocks
|
|
||||||
txt = txt + txt_mod1.gate * txt_attn
|
|
||||||
txt_mlp = self.txt_mlp(
|
|
||||||
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.sharding_group is not None:
|
|
||||||
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
|
|
||||||
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
|
|
||||||
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
|
|
||||||
|
|
||||||
# finalize the img/txt blocks
|
|
||||||
img = img + img_mod2.gate * img_mlp
|
|
||||||
txt = txt + txt_mod2.gate * txt_mlp
|
|
||||||
|
|
||||||
return img, txt
|
|
||||||
|
|
||||||
|
|
||||||
class SingleStreamBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
num_heads: int,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
qk_scale: Optional[float] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_dim = hidden_size
|
|
||||||
self.num_heads = num_heads
|
|
||||||
head_dim = hidden_size // num_heads
|
|
||||||
self.scale = qk_scale or head_dim**-0.5
|
|
||||||
|
|
||||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
|
||||||
# qkv and mlp_in
|
|
||||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
|
||||||
# proj and mlp_out
|
|
||||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
|
||||||
|
|
||||||
self.norm = QKNorm(head_dim)
|
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
|
||||||
|
|
||||||
self.mlp_act = nn.GELU(approx="tanh")
|
|
||||||
self.modulation = Modulation(hidden_size, double=False)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
|
|
||||||
B, L, _ = x.shape
|
|
||||||
H = self.num_heads
|
|
||||||
|
|
||||||
mod, _ = self.modulation(vec)
|
|
||||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
|
||||||
|
|
||||||
q, k, v, mlp = mx.split(
|
|
||||||
self.linear1(x_mod),
|
|
||||||
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
|
||||||
q, k = self.norm(q, k)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
y = _attention(q, k, v, pe)
|
|
||||||
|
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
|
||||||
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
|
|
||||||
return x + mod.gate * y
|
|
||||||
|
|
||||||
|
|
||||||
class LastLayer(nn.Module):
|
|
||||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
|
||||||
super().__init__()
|
|
||||||
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
|
||||||
self.linear = nn.Linear(
|
|
||||||
hidden_size, patch_size * patch_size * out_channels, bias=True
|
|
||||||
)
|
|
||||||
self.adaLN_modulation = nn.Sequential(
|
|
||||||
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array, vec: mx.array):
|
|
||||||
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
|
|
||||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
|
||||||
x = self.linear(x)
|
|
||||||
return x
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
|
||||||
|
|
||||||
from .layers import (
|
|
||||||
DoubleStreamBlock,
|
|
||||||
EmbedND,
|
|
||||||
LastLayer,
|
|
||||||
MLPEmbedder,
|
|
||||||
SingleStreamBlock,
|
|
||||||
timestep_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FluxParams:
|
|
||||||
in_channels: int
|
|
||||||
vec_in_dim: int
|
|
||||||
context_in_dim: int
|
|
||||||
hidden_size: int
|
|
||||||
mlp_ratio: float
|
|
||||||
num_heads: int
|
|
||||||
depth: int
|
|
||||||
depth_single_blocks: int
|
|
||||||
axes_dim: list[int]
|
|
||||||
theta: int
|
|
||||||
qkv_bias: bool
|
|
||||||
guidance_embed: bool
|
|
||||||
|
|
||||||
|
|
||||||
class Flux(nn.Module):
|
|
||||||
def __init__(self, params: FluxParams):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.params = params
|
|
||||||
self.in_channels = params.in_channels
|
|
||||||
self.out_channels = self.in_channels
|
|
||||||
if params.hidden_size % params.num_heads != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
|
||||||
)
|
|
||||||
pe_dim = params.hidden_size // params.num_heads
|
|
||||||
if sum(params.axes_dim) != pe_dim:
|
|
||||||
raise ValueError(
|
|
||||||
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
|
||||||
)
|
|
||||||
self.hidden_size = params.hidden_size
|
|
||||||
self.num_heads = params.num_heads
|
|
||||||
self.pe_embedder = EmbedND(
|
|
||||||
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
|
||||||
)
|
|
||||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
|
||||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
|
||||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
|
||||||
self.guidance_in = (
|
|
||||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
|
||||||
if params.guidance_embed
|
|
||||||
else nn.Identity()
|
|
||||||
)
|
|
||||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
|
||||||
|
|
||||||
self.double_blocks = [
|
|
||||||
DoubleStreamBlock(
|
|
||||||
self.hidden_size,
|
|
||||||
self.num_heads,
|
|
||||||
mlp_ratio=params.mlp_ratio,
|
|
||||||
qkv_bias=params.qkv_bias,
|
|
||||||
)
|
|
||||||
for _ in range(params.depth)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.single_blocks = [
|
|
||||||
SingleStreamBlock(
|
|
||||||
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
|
|
||||||
)
|
|
||||||
for _ in range(params.depth_single_blocks)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
new_weights = {}
|
|
||||||
for k, w in weights.items():
|
|
||||||
if k.startswith("model.diffusion_model."):
|
|
||||||
k = k[22:]
|
|
||||||
if k.endswith(".scale"):
|
|
||||||
k = k[:-6] + ".weight"
|
|
||||||
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
|
|
||||||
if f".{seq}." in k:
|
|
||||||
k = k.replace(f".{seq}.", f".{seq}.layers.")
|
|
||||||
break
|
|
||||||
new_weights[k] = w
|
|
||||||
return new_weights
|
|
||||||
|
|
||||||
def shard(self, group: Optional[mx.distributed.Group] = None):
|
|
||||||
group = group or mx.distributed.init()
|
|
||||||
N = group.size()
|
|
||||||
if N == 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
for block in self.double_blocks:
|
|
||||||
block.num_heads //= N
|
|
||||||
block.img_attn.num_heads //= N
|
|
||||||
block.txt_attn.num_heads //= N
|
|
||||||
block.sharding_group = group
|
|
||||||
block.img_attn.qkv = shard_linear(
|
|
||||||
block.img_attn.qkv, "all-to-sharded", segments=3, group=group
|
|
||||||
)
|
|
||||||
block.txt_attn.qkv = shard_linear(
|
|
||||||
block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
|
|
||||||
)
|
|
||||||
shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
|
|
||||||
shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
|
|
||||||
block.img_mlp.layers[0] = shard_linear(
|
|
||||||
block.img_mlp.layers[0], "all-to-sharded", group=group
|
|
||||||
)
|
|
||||||
block.txt_mlp.layers[0] = shard_linear(
|
|
||||||
block.txt_mlp.layers[0], "all-to-sharded", group=group
|
|
||||||
)
|
|
||||||
shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
|
|
||||||
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
|
|
||||||
|
|
||||||
for block in self.single_blocks:
|
|
||||||
block.num_heads //= N
|
|
||||||
block.hidden_size //= N
|
|
||||||
block.linear1 = shard_linear(
|
|
||||||
block.linear1,
|
|
||||||
"all-to-sharded",
|
|
||||||
segments=[1 / 7, 2 / 7, 3 / 7],
|
|
||||||
group=group,
|
|
||||||
)
|
|
||||||
block.linear2 = shard_linear(
|
|
||||||
block.linear2, "sharded-to-all", segments=[1 / 5], group=group
|
|
||||||
)
|
|
||||||
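Editor's note (not part of the original file): shard implements simple tensor parallelism over a group of N ranks: each rank keeps num_heads / N attention heads and 1 / N of the MLP hidden units ("all-to-sharded" input projections), and the partial outputs are recombined with an all-reduce, either inside the "sharded-to-all" output linears or via the explicit mx.distributed.all_sum calls in DoubleStreamBlock.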
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
img: mx.array,
|
|
||||||
img_ids: mx.array,
|
|
||||||
txt: mx.array,
|
|
||||||
txt_ids: mx.array,
|
|
||||||
timesteps: mx.array,
|
|
||||||
y: mx.array,
|
|
||||||
guidance: Optional[mx.array] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
img = self.img_in(img)
|
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
|
||||||
if self.params.guidance_embed:
|
|
||||||
if guidance is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Didn't get guidance strength for guidance distilled model."
|
|
||||||
)
|
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
|
||||||
vec = vec + self.vector_in(y)
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
ids = mx.concatenate([txt_ids, img_ids], axis=1)
|
|
||||||
pe = self.pe_embedder(ids).astype(img.dtype)
|
|
||||||
|
|
||||||
for block in self.double_blocks:
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
||||||
|
|
||||||
img = mx.concatenate([txt, img], axis=1)
|
|
||||||
for block in self.single_blocks:
|
|
||||||
img = block(img, vec=vec, pe=pe)
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
|
||||||
|
|
||||||
img = self.final_layer(img, vec)
|
|
||||||
|
|
||||||
return img
|
|
||||||
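Editor's aside (not part of the original file): a hypothetical smoke test of the transformer with tiny, made-up hyperparameters (the released models use hidden_size=3072, depth=19, depth_single_blocks=38).

import mlx.core as mx

params = FluxParams(
    in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=256,
    mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[16, 24, 24], theta=10_000, qkv_bias=True, guidance_embed=False,
)
flow = Flux(params)
img = mx.zeros((1, 16, 64))                     # 16 packed 2x2 latent patches
img_ids = mx.zeros((1, 16, 3), dtype=mx.int32)
txt = mx.zeros((1, 8, 4096))
txt_ids = mx.zeros((1, 8, 3), dtype=mx.int32)
out = flow(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
           timesteps=mx.ones((1,)), y=mx.zeros((1, 768)))  # -> (1, 16, 64)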
@@ -1,57 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
|
|
||||||
class FluxSampler:
|
|
||||||
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
|
|
||||||
self._base_shift = base_shift
|
|
||||||
self._max_shift = max_shift
|
|
||||||
self._schnell = "schnell" in name
|
|
||||||
|
|
||||||
def _time_shift(self, x, t):
|
|
||||||
x1, x2 = 256, 4096
|
|
||||||
t1, t2 = self._base_shift, self._max_shift
|
|
||||||
exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
|
|
||||||
t = exp_mu / (exp_mu + (1 / t - 1))
|
|
||||||
return t
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def timesteps(
|
|
||||||
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
|
|
||||||
):
|
|
||||||
t = mx.linspace(start, stop, num_steps + 1)
|
|
||||||
|
|
||||||
if not self._schnell:
|
|
||||||
t = self._time_shift(image_sequence_length, t)
|
|
||||||
|
|
||||||
return t.tolist()
|
|
||||||
|
|
||||||
def random_timesteps(self, B, L, dtype=mx.float32, key=None):
|
|
||||||
if self._schnell:
|
|
||||||
# TODO: Should we upweigh 1 and 0.75?
|
|
||||||
t = mx.random.randint(1, 5, shape=(B,), key=key)
|
|
||||||
t = t.astype(dtype) / 4
|
|
||||||
else:
|
|
||||||
t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
|
|
||||||
t = self._time_shift(L, t)
|
|
||||||
|
|
||||||
return t
|
|
||||||
|
|
||||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
|
||||||
return mx.random.normal(shape, dtype=dtype, key=key)
|
|
||||||
|
|
||||||
def add_noise(self, x, t, noise=None, key=None):
|
|
||||||
noise = (
|
|
||||||
noise
|
|
||||||
if noise is not None
|
|
||||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
|
||||||
)
|
|
||||||
t = t.reshape([-1] + [1] * (x.ndim - 1))
|
|
||||||
return x * (1 - t) + t * noise
|
|
||||||
|
|
||||||
def step(self, pred, x_t, t, t_prev):
|
|
||||||
return x_t + (t_prev - t) * pred
|
|
||||||
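Editor's note (not part of the original file): step is a plain Euler update of the flow ODE, integrating from t = 1 (pure noise) down to t = 0. A small illustrative sketch of the schnell schedule:

sampler = FluxSampler("flux-schnell")
ts = sampler.timesteps(4, 64 * 64 // 4)  # [1.0, 0.75, 0.5, 0.25, 0.0]; no time shift for schnell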
flux/flux/t5.py (244 lines removed)
@@ -1,244 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
_SHARED_REPLACEMENT_PATTERNS = [
|
|
||||||
(".block.", ".layers."),
|
|
||||||
(".k.", ".key_proj."),
|
|
||||||
(".o.", ".out_proj."),
|
|
||||||
(".q.", ".query_proj."),
|
|
||||||
(".v.", ".value_proj."),
|
|
||||||
("shared.", "wte."),
|
|
||||||
("lm_head.", "lm_head.linear."),
|
|
||||||
(".layer.0.layer_norm.", ".ln1."),
|
|
||||||
(".layer.1.layer_norm.", ".ln2."),
|
|
||||||
(".layer.2.layer_norm.", ".ln3."),
|
|
||||||
(".final_layer_norm.", ".ln."),
|
|
||||||
(
|
|
||||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
|
||||||
"relative_attention_bias.embeddings.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
_ENCODER_REPLACEMENT_PATTERNS = [
|
|
||||||
(".layer.0.SelfAttention.", ".attention."),
|
|
||||||
(".layer.1.DenseReluDense.", ".dense."),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class T5Config:
|
|
||||||
vocab_size: int
|
|
||||||
num_layers: int
|
|
||||||
num_heads: int
|
|
||||||
relative_attention_num_buckets: int
|
|
||||||
d_kv: int
|
|
||||||
d_model: int
|
|
||||||
feed_forward_proj: str
|
|
||||||
tie_word_embeddings: bool
|
|
||||||
|
|
||||||
d_ff: Optional[int] = None
|
|
||||||
num_decoder_layers: Optional[int] = None
|
|
||||||
relative_attention_max_distance: int = 128
|
|
||||||
layer_norm_epsilon: float = 1e-6
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, config):
|
|
||||||
return cls(
|
|
||||||
vocab_size=config["vocab_size"],
|
|
||||||
num_layers=config["num_layers"],
|
|
||||||
num_heads=config["num_heads"],
|
|
||||||
relative_attention_num_buckets=config["relative_attention_num_buckets"],
|
|
||||||
d_kv=config["d_kv"],
|
|
||||||
d_model=config["d_model"],
|
|
||||||
feed_forward_proj=config["feed_forward_proj"],
|
|
||||||
tie_word_embeddings=config["tie_word_embeddings"],
|
|
||||||
d_ff=config.get("d_ff", 4 * config["d_model"]),
|
|
||||||
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
|
|
||||||
relative_attention_max_distance=config.get(
|
|
||||||
"relative_attention_max_distance", 128
|
|
||||||
),
|
|
||||||
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RelativePositionBias(nn.Module):
|
|
||||||
def __init__(self, config: T5Config, bidirectional: bool):
|
|
||||||
self.bidirectional = bidirectional
|
|
||||||
self.num_buckets = config.relative_attention_num_buckets
|
|
||||||
self.max_distance = config.relative_attention_max_distance
|
|
||||||
self.n_heads = config.num_heads
|
|
||||||
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
|
|
||||||
num_buckets = num_buckets // 2 if bidirectional else num_buckets
|
|
||||||
max_exact = num_buckets // 2
|
|
||||||
|
|
||||||
abspos = rpos.abs()
|
|
||||||
is_small = abspos < max_exact
|
|
||||||
|
|
||||||
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
|
|
||||||
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
|
|
||||||
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
|
|
||||||
|
|
||||||
buckets = mx.where(is_small, abspos, buckets_large)
|
|
||||||
if bidirectional:
|
|
||||||
buckets = buckets + (rpos > 0) * num_buckets
|
|
||||||
else:
|
|
||||||
buckets = buckets * (rpos < 0)
|
|
||||||
|
|
||||||
return buckets
|
|
||||||
|
|
||||||
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
|
||||||
"""Compute binned relative position bias"""
|
|
||||||
context_position = mx.arange(offset, query_length)[:, None]
|
|
||||||
memory_position = mx.arange(key_length)[None, :]
|
|
||||||
|
|
||||||
# shape (query_length, key_length)
|
|
||||||
relative_position = memory_position - context_position
|
|
||||||
relative_position_bucket = self._relative_position_bucket(
|
|
||||||
relative_position,
|
|
||||||
bidirectional=self.bidirectional,
|
|
||||||
num_buckets=self.num_buckets,
|
|
||||||
max_distance=self.max_distance,
|
|
||||||
)
|
|
||||||
|
|
||||||
# shape (query_length, key_length, num_heads)
|
|
||||||
values = self.embeddings(relative_position_bucket)
|
|
||||||
|
|
||||||
# shape (num_heads, query_length, key_length)
|
|
||||||
return values.transpose(2, 0, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
|
||||||
def __init__(self, config: T5Config):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = config.d_kv * config.num_heads
|
|
||||||
self.num_heads = config.num_heads
|
|
||||||
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
|
||||||
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
|
||||||
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
|
||||||
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
queries: mx.array,
|
|
||||||
keys: mx.array,
|
|
||||||
values: mx.array,
|
|
||||||
mask: Optional[mx.array],
|
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
|
||||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
|
||||||
queries = self.query_proj(queries)
|
|
||||||
keys = self.key_proj(keys)
|
|
||||||
values = self.value_proj(values)
|
|
||||||
|
|
||||||
num_heads = self.num_heads
|
|
||||||
B, L, _ = queries.shape
|
|
||||||
_, S, _ = keys.shape
|
|
||||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
key_cache, value_cache = cache
|
|
||||||
keys = mx.concatenate([key_cache, keys], axis=2)  # sequence axis; keys here are (B, H, S, D)
|
|
||||||
values = mx.concatenate([value_cache, values], axis=2)
|
|
||||||
|
|
||||||
values_hat = mx.fast.scaled_dot_product_attention(
|
|
||||||
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
|
|
||||||
)
|
|
||||||
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
|
|
||||||
return self.out_proj(values_hat), (keys, values)
|
|
||||||
|
|
||||||
|
|
||||||
class DenseActivation(nn.Module):
|
|
||||||
def __init__(self, config: T5Config):
|
|
||||||
super().__init__()
|
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
|
||||||
self.gated = config.feed_forward_proj.startswith("gated")
|
|
||||||
if self.gated:
|
|
||||||
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
|
||||||
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
|
||||||
else:
|
|
||||||
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
|
||||||
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
|
||||||
activation = config.feed_forward_proj.removeprefix("gated-")
|
|
||||||
if activation == "relu":
|
|
||||||
self.act = nn.relu
|
|
||||||
elif activation == "gelu":
|
|
||||||
self.act = nn.gelu
|
|
||||||
elif activation == "silu":
|
|
||||||
self.act = nn.silu
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown activation: {activation}")
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
if self.gated:
|
|
||||||
hidden_act = self.act(self.wi_0(x))
|
|
||||||
hidden_linear = self.wi_1(x)
|
|
||||||
x = hidden_act * hidden_linear
|
|
||||||
else:
|
|
||||||
x = self.act(self.wi(x))
|
|
||||||
return self.wo(x)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
|
||||||
def __init__(self, config: T5Config):
|
|
||||||
super().__init__()
|
|
||||||
self.attention = MultiHeadAttention(config)
|
|
||||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
||||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
||||||
self.dense = DenseActivation(config)
|
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
|
||||||
y = self.ln1(x)
|
|
||||||
y, _ = self.attention(y, y, y, mask=mask)
|
|
||||||
x = x + y
|
|
||||||
|
|
||||||
y = self.ln2(x)
|
|
||||||
y = self.dense(y)
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
|
||||||
def __init__(self, config: T5Config):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = [
|
|
||||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
|
||||||
]
|
|
||||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
||||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
|
||||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
|
||||||
pos_bias = pos_bias.astype(x.dtype)
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x, mask=pos_bias)
|
|
||||||
return self.ln(x)
|
|
||||||
|
|
||||||
|
|
||||||
class T5Encoder(nn.Module):
|
|
||||||
def __init__(self, config: T5Config):
|
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
|
||||||
self.encoder = TransformerEncoder(config)
|
|
||||||
|
|
||||||
def sanitize(self, weights):
|
|
||||||
new_weights = {}
|
|
||||||
for k, w in weights.items():
|
|
||||||
for old, new in _SHARED_REPLACEMENT_PATTERNS:
|
|
||||||
k = k.replace(old, new)
|
|
||||||
if k.startswith("encoder."):
|
|
||||||
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
|
|
||||||
k = k.replace(old, new)
|
|
||||||
new_weights[k] = w
|
|
||||||
return new_weights
|
|
||||||
|
|
||||||
def __call__(self, inputs: mx.array):
|
|
||||||
return self.encoder(self.wte(inputs))
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import regex
|
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTokenizer:
|
|
||||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
|
||||||
|
|
||||||
def __init__(self, bpe_ranks, vocab, max_length=77):
|
|
||||||
self.max_length = max_length
|
|
||||||
self.bpe_ranks = bpe_ranks
|
|
||||||
self.vocab = vocab
|
|
||||||
self.pat = regex.compile(
|
|
||||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
|
||||||
regex.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos(self):
|
|
||||||
return "<|startoftext|>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos_token(self):
|
|
||||||
return self.vocab[self.bos]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos(self):
|
|
||||||
return "<|endoftext|>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos_token(self):
|
|
||||||
return self.vocab[self.eos]
|
|
||||||
|
|
||||||
def bpe(self, text):
|
|
||||||
if text in self._cache:
|
|
||||||
return self._cache[text]
|
|
||||||
|
|
||||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
|
||||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
|
||||||
|
|
||||||
if not unique_bigrams:
|
|
||||||
return unigrams
|
|
||||||
|
|
||||||
# In every iteration try to merge the most likely bigram. If none
|
|
||||||
# was merged we are done.
|
|
||||||
#
|
|
||||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
|
||||||
while unique_bigrams:
|
|
||||||
bigram = min(
|
|
||||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
|
||||||
)
|
|
||||||
if bigram not in self.bpe_ranks:
|
|
||||||
break
|
|
||||||
|
|
||||||
new_unigrams = []
|
|
||||||
skip = False
|
|
||||||
for a, b in zip(unigrams, unigrams[1:]):
|
|
||||||
if skip:
|
|
||||||
skip = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (a, b) == bigram:
|
|
||||||
new_unigrams.append(a + b)
|
|
||||||
skip = True
|
|
||||||
|
|
||||||
else:
|
|
||||||
new_unigrams.append(a)
|
|
||||||
|
|
||||||
if not skip:
|
|
||||||
new_unigrams.append(b)
|
|
||||||
|
|
||||||
unigrams = new_unigrams
|
|
||||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
|
||||||
|
|
||||||
self._cache[text] = unigrams
|
|
||||||
|
|
||||||
return unigrams
|
|
||||||
|
|
||||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
|
||||||
if isinstance(text, list):
|
|
||||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
|
||||||
|
|
||||||
# Lower case cleanup and split according to self.pat. Hugging Face does
|
|
||||||
# a much more thorough job here but this should suffice for 95% of
|
|
||||||
# cases.
|
|
||||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
|
||||||
tokens = regex.findall(self.pat, clean_text)
|
|
||||||
|
|
||||||
# Split the tokens according to the byte-pair merge file
|
|
||||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
|
||||||
|
|
||||||
# Map to token ids and return
|
|
||||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
|
||||||
if prepend_bos:
|
|
||||||
tokens = [self.bos_token] + tokens
|
|
||||||
if append_eos:
|
|
||||||
tokens.append(self.eos_token)
|
|
||||||
|
|
||||||
if len(tokens) > self.max_length:
|
|
||||||
tokens = tokens[: self.max_length]
|
|
||||||
if append_eos:
|
|
||||||
tokens[-1] = self.eos_token
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
if not isinstance(text, list):
|
|
||||||
return self.encode([text])
|
|
||||||
|
|
||||||
tokens = self.tokenize(text)
|
|
||||||
length = max(len(t) for t in tokens)
|
|
||||||
for t in tokens:
|
|
||||||
t.extend([self.eos_token] * (length - len(t)))
|
|
||||||
|
|
||||||
return mx.array(tokens)
|
|
||||||
|
|
||||||
|
|
||||||
class T5Tokenizer:
|
|
||||||
def __init__(self, model_file, max_length=512):
|
|
||||||
self._tokenizer = SentencePieceProcessor(model_file)
|
|
||||||
self.max_length = max_length
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pad(self):
|
|
||||||
try:
|
|
||||||
return self._tokenizer.id_to_piece(self.pad_token)
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pad_token(self):
|
|
||||||
return self._tokenizer.pad_id()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos(self):
|
|
||||||
try:
|
|
||||||
return self._tokenizer.id_to_piece(self.bos_token)
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bos_token(self):
|
|
||||||
return self._tokenizer.bos_id()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos(self):
|
|
||||||
try:
|
|
||||||
return self._tokenizer.id_to_piece(self.eos_token)
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def eos_token(self):
|
|
||||||
return self._tokenizer.eos_id()
|
|
||||||
|
|
||||||
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
|
|
||||||
if isinstance(text, list):
|
|
||||||
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
|
|
||||||
|
|
||||||
tokens = self._tokenizer.encode(text)
|
|
||||||
|
|
||||||
if prepend_bos and self.bos_token >= 0:
|
|
||||||
tokens = [self.bos_token] + tokens
|
|
||||||
if append_eos and self.eos_token >= 0:
|
|
||||||
tokens.append(self.eos_token)
|
|
||||||
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
|
|
||||||
tokens += [self.pad_token] * (self.max_length - len(tokens))
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def encode(self, text, pad=True):
|
|
||||||
if not isinstance(text, list):
|
|
||||||
return self.encode([text], pad=pad)
|
|
||||||
|
|
||||||
pad_token = self.pad_token if self.pad_token >= 0 else 0
|
|
||||||
tokens = self.tokenize(text, pad=pad)
|
|
||||||
length = max(len(t) for t in tokens)
|
|
||||||
for t in tokens:
|
|
||||||
t.extend([pad_token] * (length - len(t)))
|
|
||||||
|
|
||||||
return mx.array(tokens)
|
|
||||||
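Editor's aside (not part of the original file): a hypothetical usage sketch; in practice these tokenizers are built by load_clip_tokenizer / load_t5_tokenizer from the utilities further below.

clip_tok = load_clip_tokenizer("flux-schnell")
t5_tok = load_t5_tokenizer("flux-schnell")
clip_ids = clip_tok.encode("a cat")        # (1, n_tokens) with bos/eos added
t5_ids = t5_tok.encode("a cat", pad=True)  # (1, 256) for schnell, padded to max_length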
@@ -1,98 +0,0 @@
|
|||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageFile
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .datasets import Dataset
|
|
||||||
from .flux import FluxPipeline
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
|
|
||||||
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
|
|
||||||
self.flux = flux
|
|
||||||
self.dataset = dataset
|
|
||||||
self.args = args
|
|
||||||
self.latents = []
|
|
||||||
self.t5_features = []
|
|
||||||
self.clip_features = []
|
|
||||||
|
|
||||||
def _random_crop_resize(self, img):
|
|
||||||
resolution = self.args.resolution
|
|
||||||
width, height = img.size
|
|
||||||
|
|
||||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
|
||||||
|
|
||||||
# Randomly crop the input image to between 0.8 and 1.0 of its original dimensions
|
|
||||||
crop_size = (
|
|
||||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
|
||||||
max((0.8 + 0.2 * b) * height, resolution[1]),
|
|
||||||
)
|
|
||||||
pan = (width - crop_size[0], height - crop_size[1])
|
|
||||||
img = img.crop(
|
|
||||||
(
|
|
||||||
pan[0] * c,
|
|
||||||
pan[1] * d,
|
|
||||||
crop_size[0] + pan[0] * c,
|
|
||||||
crop_size[1] + pan[1] * d,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fit the largest rectangle with the aspect ratio of the target resolution
|
|
||||||
# inside the cropped image.
|
|
||||||
width, height = crop_size
|
|
||||||
ratio = resolution[0] / resolution[1]
|
|
||||||
r1 = (height * ratio, height)
|
|
||||||
r2 = (width, width / ratio)
|
|
||||||
r = r1 if r1[0] <= width else r2
|
|
||||||
img = img.crop(
|
|
||||||
(
|
|
||||||
(width - r[0]) / 2,
|
|
||||||
(height - r[1]) / 2,
|
|
||||||
(width + r[0]) / 2,
|
|
||||||
(height + r[1]) / 2,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Finally resize the image to resolution
|
|
||||||
img = img.resize(resolution, Image.LANCZOS)
|
|
||||||
|
|
||||||
return mx.array(np.array(img))
|
|
||||||
|
|
||||||
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
|
|
||||||
for i in range(num_augmentations):
|
|
||||||
img = self._random_crop_resize(input_img)
|
|
||||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
|
||||||
x_0 = self.flux.ae.encode(img[None])
|
|
||||||
x_0 = x_0.astype(self.flux.dtype)
|
|
||||||
mx.eval(x_0)
|
|
||||||
self.latents.append(x_0)
|
|
||||||
|
|
||||||
def _encode_prompt(self, prompt):
|
|
||||||
t5_tok, clip_tok = self.flux.tokenize([prompt])
|
|
||||||
t5_feat = self.flux.t5(t5_tok)
|
|
||||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
|
||||||
mx.eval(t5_feat, clip_feat)
|
|
||||||
self.t5_features.append(t5_feat)
|
|
||||||
self.clip_features.append(clip_feat)
|
|
||||||
|
|
||||||
def encode_dataset(self):
|
|
||||||
"""Encode the images & prompt in the latent space to prepare for training."""
|
|
||||||
self.flux.ae.eval()
|
|
||||||
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
|
|
||||||
self._encode_image(image, self.args.num_augmentations)
|
|
||||||
self._encode_prompt(prompt)
|
|
||||||
|
|
||||||
def iterate(self, batch_size):
|
|
||||||
xs = mx.concatenate(self.latents)
|
|
||||||
t5 = mx.concatenate(self.t5_features)
|
|
||||||
clip = mx.concatenate(self.clip_features)
|
|
||||||
mx.eval(xs, t5, clip)
|
|
||||||
n_aug = self.args.num_augmentations
|
|
||||||
while True:
|
|
||||||
x_indices = mx.random.permutation(len(self.latents))
|
|
||||||
c_indices = x_indices // n_aug
|
|
||||||
for i in range(0, len(self.latents), batch_size):
|
|
||||||
x_i = x_indices[i : i + batch_size]
|
|
||||||
c_i = c_indices[i : i + batch_size]
|
|
||||||
yield xs[x_i], t5[c_i], clip[c_i]
|
|
||||||
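Editor's aside (not part of the original file): a hypothetical outline of how Trainer.iterate feeds FluxPipeline.training_loss during fine-tuning; flux, dataset and args are assumed to already exist, and the optimizer choice is illustrative.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()

optimizer = optim.Adam(learning_rate=1e-4)
guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)

def loss_fn(flow, x_0, t5_feat, clip_feat):
    return flux.training_loss(x_0, t5_feat, clip_feat, guidance)

loss_and_grad = nn.value_and_grad(flux.flow, loss_fn)
batches = trainer.iterate(args.batch_size)
for step, (x_0, t5_feat, clip_feat) in zip(range(args.iterations), batches):
    loss, grads = loss_and_grad(flux.flow, x_0, t5_feat, clip_feat)
    optimizer.update(flux.flow, grads)
    mx.eval(flux.flow.parameters(), optimizer.state)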
@@ -1,230 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
|
|
||||||
from .autoencoder import AutoEncoder, AutoEncoderParams
|
|
||||||
from .clip import CLIPTextModel, CLIPTextModelConfig
|
|
||||||
from .model import Flux, FluxParams
|
|
||||||
from .t5 import T5Config, T5Encoder
|
|
||||||
from .tokenizers import CLIPTokenizer, T5Tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelSpec:
|
|
||||||
params: FluxParams
|
|
||||||
ae_params: AutoEncoderParams
|
|
||||||
ckpt_path: Optional[str]
|
|
||||||
ae_path: Optional[str]
|
|
||||||
repo_id: Optional[str]
|
|
||||||
repo_flow: Optional[str]
|
|
||||||
repo_ae: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
configs = {
|
|
||||||
"flux-dev": ModelSpec(
|
|
||||||
repo_id="black-forest-labs/FLUX.1-dev",
|
|
||||||
repo_flow="flux1-dev.safetensors",
|
|
||||||
repo_ae="ae.safetensors",
|
|
||||||
ckpt_path=os.getenv("FLUX_DEV"),
|
|
||||||
params=FluxParams(
|
|
||||||
in_channels=64,
|
|
||||||
vec_in_dim=768,
|
|
||||||
context_in_dim=4096,
|
|
||||||
hidden_size=3072,
|
|
||||||
mlp_ratio=4.0,
|
|
||||||
num_heads=24,
|
|
||||||
depth=19,
|
|
||||||
depth_single_blocks=38,
|
|
||||||
axes_dim=[16, 56, 56],
|
|
||||||
theta=10_000,
|
|
||||||
qkv_bias=True,
|
|
||||||
guidance_embed=True,
|
|
||||||
),
|
|
||||||
ae_path=os.getenv("AE"),
|
|
||||||
ae_params=AutoEncoderParams(
|
|
||||||
resolution=256,
|
|
||||||
in_channels=3,
|
|
||||||
ch=128,
|
|
||||||
out_ch=3,
|
|
||||||
ch_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
z_channels=16,
|
|
||||||
scale_factor=0.3611,
|
|
||||||
shift_factor=0.1159,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
"flux-schnell": ModelSpec(
|
|
||||||
repo_id="black-forest-labs/FLUX.1-schnell",
|
|
||||||
repo_flow="flux1-schnell.safetensors",
|
|
||||||
repo_ae="ae.safetensors",
|
|
||||||
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
|
||||||
params=FluxParams(
|
|
||||||
in_channels=64,
|
|
||||||
vec_in_dim=768,
|
|
||||||
context_in_dim=4096,
|
|
||||||
hidden_size=3072,
|
|
||||||
mlp_ratio=4.0,
|
|
||||||
num_heads=24,
|
|
||||||
depth=19,
|
|
||||||
depth_single_blocks=38,
|
|
||||||
axes_dim=[16, 56, 56],
|
|
||||||
theta=10_000,
|
|
||||||
qkv_bias=True,
|
|
||||||
guidance_embed=False,
|
|
||||||
),
|
|
||||||
ae_path=os.getenv("AE"),
|
|
||||||
ae_params=AutoEncoderParams(
|
|
||||||
resolution=256,
|
|
||||||
in_channels=3,
|
|
||||||
ch=128,
|
|
||||||
out_ch=3,
|
|
||||||
ch_mult=[1, 2, 4, 4],
|
|
||||||
num_res_blocks=2,
|
|
||||||
z_channels=16,
|
|
||||||
scale_factor=0.3611,
|
|
||||||
shift_factor=0.1159,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_flow_model(name: str, hf_download: bool = True):
|
|
||||||
# Get the safetensors file to load
|
|
||||||
ckpt_path = configs[name].ckpt_path
|
|
||||||
|
|
||||||
# Download if needed
|
|
||||||
if (
|
|
||||||
ckpt_path is None
|
|
||||||
and configs[name].repo_id is not None
|
|
||||||
and configs[name].repo_flow is not None
|
|
||||||
and hf_download
|
|
||||||
):
|
|
||||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
|
||||||
|
|
||||||
# Make the model
|
|
||||||
model = Flux(configs[name].params)
|
|
||||||
|
|
||||||
# Load the checkpoint if needed
|
|
||||||
if ckpt_path is not None:
|
|
||||||
weights = mx.load(ckpt_path)
|
|
||||||
weights = model.sanitize(weights)
|
|
||||||
model.load_weights(list(weights.items()))
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_ae(name: str, hf_download: bool = True):
|
|
||||||
# Get the safetensors file to load
|
|
||||||
ckpt_path = configs[name].ae_path
|
|
||||||
|
|
||||||
# Download if needed
|
|
||||||
if (
|
|
||||||
ckpt_path is None
|
|
||||||
and configs[name].repo_id is not None
|
|
||||||
and configs[name].repo_ae is not None
|
|
||||||
and hf_download
|
|
||||||
):
|
|
||||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
|
|
||||||
|
|
||||||
# Make the autoencoder
|
|
||||||
ae = AutoEncoder(configs[name].ae_params)
|
|
||||||
|
|
||||||
# Load the checkpoint if needed
|
|
||||||
if ckpt_path is not None:
|
|
||||||
weights = mx.load(ckpt_path)
|
|
||||||
weights = ae.sanitize(weights)
|
|
||||||
ae.load_weights(list(weights.items()))
|
|
||||||
|
|
||||||
return ae
|
|
||||||
|
|
||||||
|
|
||||||
def load_clip(name: str):
|
|
||||||
# Load the config
|
|
||||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
|
|
||||||
with open(config_path) as f:
|
|
||||||
config = CLIPTextModelConfig.from_dict(json.load(f))
|
|
||||||
|
|
||||||
# Make the clip text encoder
|
|
||||||
clip = CLIPTextModel(config)
|
|
||||||
|
|
||||||
# Load the weights
|
|
||||||
ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
|
|
||||||
weights = mx.load(ckpt_path)
|
|
||||||
weights = clip.sanitize(weights)
|
|
||||||
clip.load_weights(list(weights.items()))
|
|
||||||
|
|
||||||
return clip
|
|
||||||
|
|
||||||
|
|
||||||
def load_t5(name: str):
|
|
||||||
# Load the config
|
|
||||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
|
|
||||||
with open(config_path) as f:
|
|
||||||
config = T5Config.from_dict(json.load(f))
|
|
||||||
|
|
||||||
# Make the T5 model
|
|
||||||
t5 = T5Encoder(config)
|
|
||||||
|
|
||||||
# Load the weights
|
|
||||||
model_index = hf_hub_download(
|
|
||||||
configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
|
|
||||||
)
|
|
||||||
weight_files = set()
|
|
||||||
with open(model_index) as f:
|
|
||||||
for _, w in json.load(f)["weight_map"].items():
|
|
||||||
weight_files.add(w)
|
|
||||||
weights = {}
|
|
||||||
for w in weight_files:
|
|
||||||
w = f"text_encoder_2/{w}"
|
|
||||||
w = hf_hub_download(configs[name].repo_id, w)
|
|
||||||
weights.update(mx.load(w))
|
|
||||||
weights = t5.sanitize(weights)
|
|
||||||
t5.load_weights(list(weights.items()))
|
|
||||||
|
|
||||||
return t5
|
|
||||||
|
|
||||||
|
|
||||||
def load_clip_tokenizer(name: str):
|
|
||||||
vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
|
|
||||||
with open(vocab_file, encoding="utf-8") as f:
|
|
||||||
vocab = json.load(f)
|
|
||||||
|
|
||||||
merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
|
|
||||||
with open(merges_file, encoding="utf-8") as f:
|
|
||||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
|
||||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
|
||||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
|
||||||
|
|
||||||
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
|
|
||||||
|
|
||||||
|
|
||||||
def load_t5_tokenizer(name: str, pad: bool = True):
|
|
||||||
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
|
||||||
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
|
||||||
config: dict,
|
|
||||||
config_path: Union[str, Path],
|
|
||||||
) -> None:
|
|
||||||
"""Save the model configuration to the ``config_path``.
|
|
||||||
|
|
||||||
The final configuration will be sorted before saving for better readability.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): The model configuration.
|
|
||||||
config_path (Union[str, Path]): Model configuration file path.
|
|
||||||
"""
|
|
||||||
# Sort the config for better readability
|
|
||||||
config = dict(sorted(config.items()))
|
|
||||||
|
|
||||||
# Write the config to the provided file
|
|
||||||
with open(config_path, "w") as fid:
|
|
||||||
json.dump(config, fid, indent=4)
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from flux import FluxPipeline
|
|
||||||
|
|
||||||
|
|
||||||
def print_zero(group, *args, **kwargs):
|
|
||||||
if group.rank() == 0:
|
|
||||||
flush = kwargs.pop("flush", True)
|
|
||||||
print(*args, **kwargs, flush=flush)
|
|
||||||
|
|
||||||
|
|
||||||
def quantization_predicate(name, m):
|
|
||||||
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
|
|
||||||
|
|
||||||
|
|
||||||
def to_latent_size(image_size):
|
|
||||||
h, w = image_size
|
|
||||||
h = ((h + 15) // 16) * 16
|
|
||||||
w = ((w + 15) // 16) * 16
|
|
||||||
|
|
||||||
if (h, w) != image_size:
|
|
||||||
print(
|
|
||||||
"Warning: The image dimensions need to be divisible by 16px. "
|
|
||||||
f"Changing size to {h}x{w}."
|
|
||||||
)
|
|
||||||
|
|
||||||
return (h // 8, w // 8)
|
|
||||||
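Editor's note (not part of the original file): latents are 1/8 of the pixel resolution, so to_latent_size((512, 512)) returns (64, 64), while a 1000x600 request is first rounded up to 1008x608 and then mapped to (126, 76).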
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Generate images from a textual prompt using FLUX"
|
|
||||||
)
|
|
||||||
parser.add_argument("--quantize", "-q", action="store_true")
|
|
||||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
|
||||||
parser.add_argument("--output", default="out.png")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
flux = FluxPipeline("flux-" + args.model, t5_padding=True)
|
|
||||||
|
|
||||||
if args.quantize:
|
|
||||||
nn.quantize(flux.flow, class_predicate=quantization_predicate)
|
|
||||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
|
||||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
|
||||||
|
|
||||||
group = mx.distributed.init()
|
|
||||||
if group.size() > 1:
|
|
||||||
flux.flow.shard(group)
|
|
||||||
|
|
||||||
print_zero(group, "Loading models")
|
|
||||||
flux.ensure_models_are_loaded()
|
|
||||||
|
|
||||||
def print_help():
|
|
||||||
print_zero(group, "The command list:")
|
|
||||||
print_zero(group, "- 'q' to exit")
|
|
||||||
print_zero(group, "- 's HxW' to change the size of the image")
|
|
||||||
print_zero(group, "- 'n S' to change the number of steps")
|
|
||||||
print_zero(group, "- 'h' to print this help")
|
|
||||||
|
|
||||||
print_zero(group, "FLUX interactive session")
|
|
||||||
print_help()
|
|
||||||
seed = 0
|
|
||||||
size = (512, 512)
|
|
||||||
latent_size = to_latent_size(size)
|
|
||||||
steps = 50 if args.model == "dev" else 4
|
|
||||||
while True:
|
|
||||||
prompt = input(">> " if group.rank() == 0 else "")
|
|
||||||
if prompt == "q":
|
|
||||||
break
|
|
||||||
if prompt == "h":
|
|
||||||
print_help()
|
|
||||||
continue
|
|
||||||
if prompt.startswith("s "):
|
|
||||||
size = tuple([int(xi) for xi in prompt[2:].split("x")])
|
|
||||||
print_zero(group, "Setting the size to", size)
|
|
||||||
latent_size = to_latent_size(size)
|
|
||||||
continue
|
|
||||||
if prompt.startswith("n "):
|
|
||||||
steps = int(prompt[2:])
|
|
||||||
print_zero(group, "Setting the steps to", steps)
|
|
||||||
continue
|
|
||||||
|
|
||||||
seed += 1
|
|
||||||
latents = flux.generate_latents(
|
|
||||||
prompt,
|
|
||||||
n_images=1,
|
|
||||||
num_steps=steps,
|
|
||||||
latent_size=latent_size,
|
|
||||||
guidance=4.0,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
print_zero(group, "Processing prompt")
|
|
||||||
mx.eval(next(latents))
|
|
||||||
print_zero(group, "Generating latents")
|
|
||||||
for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
|
|
||||||
mx.eval(xt)
|
|
||||||
print_zero(group, "Generating image")
|
|
||||||
xt = flux.decode(xt, latent_size)
|
|
||||||
xt = (xt * 255).astype(mx.uint8)
|
|
||||||
mx.eval(xt)
|
|
||||||
im = Image.fromarray(np.array(xt[0]))
|
|
||||||
im.save(args.output)
|
|
||||||
print_zero(group, "Saved at", args.output, end="\n\n")
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
mlx>=0.18.1
huggingface-hub
regex
numpy
tqdm
Pillow
sentencepiece
(Four binary image assets removed, 754 KiB, 423 KiB, 434 KiB, and 153 KiB; files not shown.)
@@ -1,175 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from flux import FluxPipeline
|
|
||||||
|
|
||||||
|
|
||||||
def to_latent_size(image_size):
|
|
||||||
h, w = image_size
|
|
||||||
h = ((h + 15) // 16) * 16
|
|
||||||
w = ((w + 15) // 16) * 16
|
|
||||||
|
|
||||||
if (h, w) != image_size:
|
|
||||||
print(
|
|
||||||
"Warning: The image dimensions need to be divisible by 16px. "
|
|
||||||
f"Changing size to {h}x{w}."
|
|
||||||
)
|
|
||||||
|
|
||||||
return (h // 8, w // 8)
|
|
||||||
|
|
||||||
|
|
||||||
def quantization_predicate(name, m):
|
|
||||||
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
|
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(flux, adapter_file, fuse=False):
|
|
||||||
weights, lora_config = mx.load(adapter_file, return_metadata=True)
|
|
||||||
rank = int(lora_config["lora_rank"])
|
|
||||||
num_blocks = int(lora_config["lora_blocks"])
|
|
||||||
flux.linear_to_lora_layers(rank, num_blocks)
|
|
||||||
flux.flow.load_weights(list(weights.items()), strict=False)
|
|
||||||
if fuse:
|
|
||||||
flux.fuse_lora_layers()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Generate images from a textual prompt using FLUX"
|
|
||||||
)
|
|
||||||
parser.add_argument("prompt")
|
|
||||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
|
||||||
parser.add_argument("--n-images", type=int, default=4)
|
|
||||||
parser.add_argument(
|
|
||||||
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
|
|
||||||
)
|
|
||||||
parser.add_argument("--steps", type=int)
|
|
||||||
parser.add_argument("--guidance", type=float, default=4.0)
|
|
||||||
parser.add_argument("--n-rows", type=int, default=1)
|
|
||||||
parser.add_argument("--decoding-batch-size", type=int, default=1)
|
|
||||||
parser.add_argument("--quantize", "-q", action="store_true")
|
|
||||||
parser.add_argument("--preload-models", action="store_true")
|
|
||||||
parser.add_argument("--output", default="out.png")
|
|
||||||
parser.add_argument("--save-raw", action="store_true")
|
|
||||||
parser.add_argument("--seed", type=int)
|
|
||||||
parser.add_argument("--verbose", "-v", action="store_true")
|
|
||||||
parser.add_argument("--adapter")
|
|
||||||
parser.add_argument("--fuse-adapter", action="store_true")
|
|
||||||
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
|
|
||||||
parser.add_argument("--force-shard", action="store_true")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Load the models
|
|
||||||
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
|
|
||||||
args.steps = args.steps or (50 if args.model == "dev" else 2)
|
|
||||||
|
|
||||||
if args.adapter:
|
|
||||||
load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
|
|
||||||
|
|
||||||
if args.quantize:
|
|
||||||
nn.quantize(flux.flow, class_predicate=quantization_predicate)
|
|
||||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
|
||||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
|
||||||
|
|
||||||
# Figure out what kind of distributed generation we should do
|
|
||||||
group = mx.distributed.init()
|
|
||||||
n_images = args.n_images
|
|
||||||
should_gather = False
|
|
||||||
if group.size() > 1:
|
|
||||||
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
|
|
||||||
flux.flow.shard(group)
|
|
||||||
else:
|
|
||||||
n_images //= group.size()
|
|
||||||
should_gather = True
|
|
||||||
|
|
||||||
# If we are sharding we should have the same seed and if we are doing
|
|
||||||
# data parallel generation we should have different seeds
|
|
||||||
if args.seed is None:
|
|
||||||
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
|
||||||
if should_gather:
|
|
||||||
args.seed = args.seed + group.rank()
|
|
||||||
|
|
||||||
if args.preload_models:
|
|
||||||
flux.ensure_models_are_loaded()
|
|
||||||
|
|
||||||
# Make the generator
|
|
||||||
latent_size = to_latent_size(args.image_size)
|
|
||||||
latents = flux.generate_latents(
|
|
||||||
args.prompt,
|
|
||||||
n_images=n_images,
|
|
||||||
num_steps=args.steps,
|
|
||||||
latent_size=latent_size,
|
|
||||||
guidance=args.guidance,
|
|
||||||
seed=args.seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
# First we get and eval the conditioning
|
|
||||||
conditioning = next(latents)
|
|
||||||
mx.eval(conditioning)
|
|
||||||
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
|
|
||||||
mx.reset_peak_memory()
|
|
||||||
|
|
||||||
# The following is not necessary but it may help in memory constrained
|
|
||||||
# systems by reusing the memory kept by the text encoders.
|
|
||||||
del flux.t5
|
|
||||||
del flux.clip
|
|
||||||
|
|
||||||
# Actual denoising loop
|
|
||||||
for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
|
|
||||||
mx.eval(x_t)
|
|
||||||
|
|
||||||
# The following is not necessary but it may help in memory constrained
|
|
||||||
# systems by reusing the memory kept by the flow transformer.
|
|
||||||
del flux.flow
|
|
||||||
peak_mem_generation = mx.get_peak_memory() / 1024**3
|
|
||||||
mx.reset_peak_memory()
|
|
||||||
|
|
||||||
# Decode them into images
|
|
||||||
decoded = []
|
|
||||||
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
|
|
||||||
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
|
||||||
mx.eval(decoded[-1])
|
|
||||||
peak_mem_decoding = mx.get_peak_memory() / 1024**3
|
|
||||||
peak_mem_overall = max(
|
|
||||||
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
|
|
||||||
)
|
|
||||||
|
|
||||||
# Gather them if each node has different images
|
|
||||||
decoded = mx.concatenate(decoded, axis=0)
|
|
||||||
if should_gather:
|
|
||||||
decoded = mx.distributed.all_gather(decoded)
|
|
||||||
mx.eval(decoded)
|
|
||||||
|
|
||||||
if args.save_raw:
|
|
||||||
*name, suffix = args.output.split(".")
|
|
||||||
name = ".".join(name)
|
|
||||||
x = decoded
|
|
||||||
x = (x * 255).astype(mx.uint8)
|
|
||||||
for i in range(len(x)):
|
|
||||||
im = Image.fromarray(np.array(x[i]))
|
|
||||||
im.save(".".join([name, str(i), suffix]))
|
|
||||||
else:
|
|
||||||
# Arrange them on a grid
|
|
||||||
x = decoded
|
|
||||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
|
||||||
B, H, W, C = x.shape
|
|
||||||
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
|
||||||
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
|
|
||||||
x = (x * 255).astype(mx.uint8)
|
|
||||||
|
|
||||||
# Save them to disc
|
|
||||||
im = Image.fromarray(np.array(x))
|
|
||||||
im.save(args.output)
|
|
||||||
|
|
||||||
# Report the peak memory used during generation
|
|
||||||
if args.verbose and group.rank() == 0:
|
|
||||||
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
|
|
||||||
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
|
||||||
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
|
||||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
|
||||||
@@ -79,10 +79,10 @@ def load_image(image_source):
|
|||||||
def prepare_inputs(processor, image, prompt):
|
def prepare_inputs(processor, image, prompt):
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = load_image(image)
|
image = load_image(image)
|
||||||
inputs = processor(image, prompt, return_tensors="np")
|
inputs = processor(prompt, image, return_tensors="np")
|
||||||
pixel_values = mx.array(inputs["pixel_values"])
|
pixel_values = mx.array(inputs["pixel_values"])
|
||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
return pixel_values, input_ids
|
return input_ids, pixel_values
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path, tokenizer_config={}):
|
def load_model(model_path, tokenizer_config={}):
|
||||||
@@ -126,7 +126,8 @@ def main():
|
|||||||
processor, model = load_model(args.model, tokenizer_config)
|
processor, model = load_model(args.model, tokenizer_config)
|
||||||
|
|
||||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||||
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
|
|
||||||
|
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
|
||||||
|
|
||||||
print(prompt)
|
print(prompt)
|
||||||
generated_text = generate_text(
|
generated_text = generate_text(
|
||||||
|
|||||||
@@ -68,10 +68,11 @@ class LlavaModel(nn.Module):
|
|||||||
input_ids: Optional[mx.array] = None,
|
input_ids: Optional[mx.array] = None,
|
||||||
pixel_values: Optional[mx.array] = None,
|
pixel_values: Optional[mx.array] = None,
|
||||||
):
|
):
|
||||||
|
if pixel_values is None:
|
||||||
|
return self.language_model(input_ids)
|
||||||
|
|
||||||
# Get the input embeddings from the language model
|
# Get the input embeddings from the language model
|
||||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
if pixel_values is None:
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
# Get the output hidden states from the vision model
|
# Get the output hidden states from the vision model
|
||||||
*_, hidden_states = self.vision_tower(
|
*_, hidden_states = self.vision_tower(
|
||||||
@@ -104,21 +105,31 @@ class LlavaModel(nn.Module):
|
|||||||
self, image_features, inputs_embeds, input_ids
|
self, image_features, inputs_embeds, input_ids
|
||||||
):
|
):
|
||||||
image_token_index = self.config.image_token_index
|
image_token_index = self.config.image_token_index
|
||||||
batch_size, num_image_patches, embed_dim = image_features.shape
|
num_images, num_image_patches, embed_dim = image_features.shape
|
||||||
|
|
||||||
# Positions of <image> tokens in input_ids, assuming batch size is 1
|
# Positions of <image> tokens in input_ids, assuming batch size is 1
|
||||||
image_positions = mx.array(
|
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
|
||||||
np.where(input_ids[0] == image_token_index)[0], mx.uint32
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(image_positions) != num_image_patches:
|
if len(image_positions) != num_images:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of image tokens ({len(image_positions)}) does not "
|
f"The number of image tokens ({len(image_positions)}) does not "
|
||||||
f" match the number of image patches ({num_image_patches})."
|
f" match the number of image inputs ({num_images})."
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs_embeds[0, image_positions] = image_features
|
text_segments = []
|
||||||
return inputs_embeds
|
start_idx = 0
|
||||||
|
|
||||||
|
for position in image_positions:
|
||||||
|
text_segments.append(inputs_embeds[:, start_idx:position])
|
||||||
|
start_idx = position + 1
|
||||||
|
|
||||||
|
image_embeddings = mx.split(image_features, image_features.shape[0])
|
||||||
|
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
|
||||||
|
final_embeddings += [inputs_embeds[:, start_idx:]]
|
||||||
|
|
||||||
|
# Create a final embedding of shape
|
||||||
|
# (1, num_image_patches*num_images + sequence_len, embed_dim)
|
||||||
|
return mx.concatenate(final_embeddings, axis=1)
|
||||||
|
|
||||||
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
|
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
|
||||||
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
|
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
|
||||||
|
|||||||
llms/CONTRIBUTING.md (new file, 47 lines)
@@ -0,0 +1,47 @@
|
|||||||
|
# Contributing to MLX LM
|
||||||
|
|
||||||
|
Below are some tips to port LLMs available on Hugging Face to MLX.
|
||||||
|
|
||||||
|
Before starting, check out the [general contribution
|
||||||
|
guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md).
|
||||||
|
|
||||||
|
Next, from this directory, do an editable install:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
Then check if the model has weights in the
|
||||||
|
[safetensors](https://huggingface.co/docs/safetensors/index) format. If not
|
||||||
|
[follow instructions](https://huggingface.co/spaces/safetensors/convert) to
|
||||||
|
convert it.
|
||||||
|
|
||||||
|
After that, add the model file to the
|
||||||
|
[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models)
|
||||||
|
directory. You can see other examples there. We recommend starting from a model
|
||||||
|
that is similar to the model you are porting.
|
||||||
|
|
||||||
|
Make sure the name of the new model file is the same as the `model_type` in the
|
||||||
|
`config.json`, for example
|
||||||
|
[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17).
|
||||||
|
|
||||||
|
To determine the model layer names, we suggest either:
|
||||||
|
|
||||||
|
- Refer to the Transformers implementation if you are familiar with the
|
||||||
|
codebase.
|
||||||
|
- Load the model weights and check the weight names which will tell you about
|
||||||
|
the model structure.
|
||||||
|
- Look at the names of the weights by inspecting `model.safetensors.index.json`
  in the Hugging Face repo (a short sketch of this follows below).
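As a rough sketch (using `bigcode/starcoder2-7b` purely as an example repo), the
weight names can be listed from the safetensors index without downloading the
full weights:

```python
# Hypothetical sketch: list weight names from the safetensors index to infer
# the module structure of the model being ported.
import json

from huggingface_hub import hf_hub_download

index_path = hf_hub_download("bigcode/starcoder2-7b", "model.safetensors.index.json")
with open(index_path) as f:
    index = json.load(f)

# Print a few names, e.g. "model.layers.0.self_attn.q_proj.weight"
for name in sorted(index["weight_map"])[:10]:
    print(name)
```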
|
|
||||||
|
To add LoRA support edit
|
||||||
|
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
|
||||||
|
|
||||||
|
Finally, add a test for the new model type to the [model
|
||||||
|
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
|
||||||
|
|
||||||
|
From the `llms/` directory, you can run the tests with:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python -m unittest discover tests/
|
||||||
|
```
|
||||||
llms/MANIFEST.in (new file, 2 lines)
@@ -0,0 +1,2 @@
|
|||||||
|
include mlx_lm/requirements.txt
|
||||||
|
recursive-include mlx_lm/ *.py
|
||||||
llms/README.md (171 lines changed)
@@ -1,6 +1,169 @@
|
|||||||
# MOVE NOTICE
|
## Generate Text with LLMs and MLX
|
||||||
|
|
||||||
The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).
|
The easiest way to get started is to install the `mlx-lm` package:
|
||||||
|
|
||||||
The package has been removed from the MLX Examples repo. Send new contributions
|
**With `pip`**:
|
||||||
and issues to the MLX LM repo.
|
|
||||||
|
```sh
|
||||||
|
pip install mlx-lm
|
||||||
|
```
|
||||||
|
|
||||||
|
**With `conda`**:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
conda install -c conda-forge mlx-lm
|
||||||
|
```
|
||||||
|
|
||||||
|
The `mlx-lm` package also has:
|
||||||
|
|
||||||
|
- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
|
||||||
|
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
||||||
|
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
||||||
|
|
||||||
|
### Python API
|
||||||
|
|
||||||
|
You can use `mlx-lm` as a module:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlx_lm import load, generate
|
||||||
|
|
||||||
|
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||||
|
|
||||||
|
response = generate(model, tokenizer, prompt="hello", verbose=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
To see a description of all the arguments you can do:
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> help(generate)
|
||||||
|
```
|
||||||
|
|
||||||
|
The `mlx-lm` package also comes with functionality to quantize and optionally
|
||||||
|
upload models to the Hugging Face Hub.
|
||||||
|
|
||||||
|
You can convert models in the Python API with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlx_lm import convert
|
||||||
|
|
||||||
|
repo = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||||
|
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
|
||||||
|
|
||||||
|
convert(repo, quantize=True, upload_repo=upload_repo)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will generate a 4-bit quantized Mistral 7B and upload it to the repo
|
||||||
|
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
|
||||||
|
converted model in the path `mlx_model` by default.
|
||||||
|
|
||||||
|
To see a description of all the arguments you can do:
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> help(convert)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Streaming
|
||||||
|
|
||||||
|
For streaming generation, use the `stream_generate` function. This returns a
|
||||||
|
generator object which streams the output text. For example,
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlx_lm import load, stream_generate
|
||||||
|
|
||||||
|
repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||||
|
model, tokenizer = load(repo)
|
||||||
|
|
||||||
|
prompt = "Write a story about Einstein"
|
||||||
|
|
||||||
|
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||||
|
print(t, end="", flush=True)
|
||||||
|
print()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Command Line
|
||||||
|
|
||||||
|
You can also use `mlx-lm` from the command line with:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download a Mistral 7B model from the Hugging Face Hub and generate
|
||||||
|
text using the given prompt.
|
||||||
|
|
||||||
|
For a full list of options run:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.generate --help
|
||||||
|
```
|
||||||
|
|
||||||
|
To quantize a model from the command line run:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
|
||||||
|
```
|
||||||
|
|
||||||
|
For more options run:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.convert --help
|
||||||
|
```
|
||||||
|
|
||||||
|
You can upload new models to Hugging Face by specifying `--upload-repo` to
|
||||||
|
`convert`. For example, to upload a quantized Mistral-7B model to the
|
||||||
|
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.convert \
|
||||||
|
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
|
||||||
|
-q \
|
||||||
|
--upload-repo mlx-community/my-4bit-mistral
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Models
|
||||||
|
|
||||||
|
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
|
||||||
|
models. If the model you want to run is not supported, file an
|
||||||
|
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||||
|
submit a pull request.
|
||||||
|
|
||||||
|
Here are a few examples of Hugging Face models that work with this example:
|
||||||
|
|
||||||
|
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||||
|
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
|
||||||
|
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
|
||||||
|
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
|
||||||
|
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
|
||||||
|
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||||
|
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
|
||||||
|
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
|
||||||
|
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||||
|
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||||
|
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||||
|
|
||||||
|
Most
|
||||||
|
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||||
|
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
|
||||||
|
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
|
||||||
|
and
|
||||||
|
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
|
||||||
|
style models should work out of the box.
|
||||||
|
|
||||||
|
For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
|
||||||
|
enable the `trust_remote_code` option. You can do this by passing
|
||||||
|
`--trust-remote-code` in the command line. If you don't specify the flag
|
||||||
|
explicitly, you will be prompted to trust remote code in the terminal when
|
||||||
|
running the model.
|
||||||
|
|
||||||
|
For `Qwen` models you must also specify the `eos_token`. You can do this by
|
||||||
|
passing `--eos-token "<|endoftext|>"` in the command
|
||||||
|
line.
|
||||||
|
|
||||||
|
These options can also be set in the Python API. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model, tokenizer = load(
|
||||||
|
"qwen/Qwen-7B",
|
||||||
|
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def generate(
|
|||||||
if len(tokens) == 0:
|
if len(tokens) == 0:
|
||||||
print("No tokens generated for this prompt")
|
print("No tokens generated for this prompt")
|
||||||
return
|
return
|
||||||
prompt_tps = len(prompt) / prompt_time
|
prompt_tps = prompt.size / prompt_time
|
||||||
gen_tps = (len(tokens) - 1) / gen_time
|
gen_tps = (len(tokens) - 1) / gen_time
|
||||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ class ModelArgs:
|
|||||||
rms_norm_eps: float
|
rms_norm_eps: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
context_length: int
|
context_length: int
|
||||||
num_key_value_heads: Optional[int] = None
|
num_key_value_heads: int = None
|
||||||
rope_theta: float = 10000
|
rope_theta: float = 10000
|
||||||
rope_traditional: bool = False
|
rope_traditional: bool = False
|
||||||
model_type: Optional[str] = None
|
model_type: str = None
|
||||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -54,7 +54,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
dim = args.hidden_size
|
dim = args.hidden_size
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
|
||||||
self.repeats = n_heads // n_kv_heads
|
self.repeats = n_heads // n_kv_heads
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ class Attention(nn.Module):
|
|||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||||
rope_scale = (
|
rope_scale = (
|
||||||
1 / float(args.rope_scaling["factor"])
|
1 / args.rope_scaling["factor"]
|
||||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||||
else 1
|
else 1
|
||||||
)
|
)
|
||||||
@@ -254,7 +254,7 @@ def translate_weight_names(name):
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def load(gguf_file: str, repo: Optional[str] = None):
|
def load(gguf_file: str, repo: str = None):
|
||||||
# If the gguf_file exists, try to load model from it.
|
# If the gguf_file exists, try to load model from it.
|
||||||
# Otherwise try to download and cache from the HF repo
|
# Otherwise try to download and cache from the HF repo
|
||||||
if not Path(gguf_file).exists():
|
if not Path(gguf_file).exists():
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import glob
|
|||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -150,8 +149,7 @@ def quantize(weights, config, args):
|
|||||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||||
shards = []
|
shards = []
|
||||||
shard: Dict[str, mx.array] = {}
|
shard, shard_size = {}, 0
|
||||||
shard_size = 0
|
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
if shard_size + v.nbytes > max_file_size_bytes:
|
if shard_size + v.nbytes > max_file_size_bytes:
|
||||||
shards.append(shard)
|
shards.append(shard)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class ModelArgs:
|
|||||||
n_kv_heads: int
|
n_kv_heads: int
|
||||||
norm_eps: float
|
norm_eps: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
moe: dict
|
moe: dict = None
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
@@ -91,6 +91,7 @@ class FeedForward(nn.Module):
|
|||||||
class MOEFeedForward(nn.Module):
|
class MOEFeedForward(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_experts = args.moe["num_experts"]
|
self.num_experts = args.moe["num_experts"]
|
||||||
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
||||||
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
|
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
|
||||||
@@ -114,6 +115,7 @@ class MOEFeedForward(nn.Module):
|
|||||||
yt = (yt * st).sum(axis=-1)
|
yt = (yt * st).sum(axis=-1)
|
||||||
y.append(yt[None, :])
|
y.append(yt[None, :])
|
||||||
y = mx.concatenate(y)
|
y = mx.concatenate(y)
|
||||||
|
|
||||||
return y.reshape(orig_shape)
|
return y.reshape(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
llms/mlx_lm/LORA.md (new file, 266 lines)
@@ -0,0 +1,266 @@
|
|||||||
|
# Fine-Tuning with LoRA or QLoRA
|
||||||
|
|
||||||
|
You can use the `mlx-lm` package to fine-tune an LLM with low rank
|
||||||
|
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
|
||||||
|
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
|
||||||
|
|
||||||
|
- Mistral
|
||||||
|
- Llama
|
||||||
|
- Phi2
|
||||||
|
- Mixtral
|
||||||
|
- Qwen2
|
||||||
|
- Gemma
|
||||||
|
- OLMo
|
||||||
|
- MiniCPM
|
||||||
|
- InternLM2
|
||||||
|
|
||||||
|
## Contents
|
||||||
|
|
||||||
|
- [Run](#Run)
|
||||||
|
- [Fine-tune](#Fine-tune)
|
||||||
|
- [Evaluate](#Evaluate)
|
||||||
|
- [Generate](#Generate)
|
||||||
|
- [Fuse](#Fuse)
|
||||||
|
- [Data](#Data)
|
||||||
|
- [Memory Issues](#Memory-Issues)
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.lora --help
|
||||||
|
```
|
||||||
|
|
||||||
|
Note, in the following the `--model` argument can be any compatible Hugging
|
||||||
|
Face repo or a local path to a converted model.
|
||||||
|
|
||||||
|
You can also specify a YAML config with `-c`/`--config`. For more on the format see the
|
||||||
|
[example YAML](examples/lora_config.yaml). For example:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.lora --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
If command-line flags are also used, they will override the corresponding
|
||||||
|
values in the config.
|
||||||
|
|
||||||
|
### Fine-tune
|
||||||
|
|
||||||
|
To fine-tune a model use:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.lora \
|
||||||
|
--model <path_to_model> \
|
||||||
|
--train \
|
||||||
|
--data <path_to_data> \
|
||||||
|
--iters 600
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
|
||||||
|
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
|
||||||
|
details on the data format see the section on [Data](#Data).
|
||||||
|
|
||||||
|
For example, to fine-tune a Mistral 7B you can use `--model
|
||||||
|
mistralai/Mistral-7B-v0.1`.
|
||||||
|
|
||||||
|
If `--model` points to a quantized model, then the training will use QLoRA,
|
||||||
|
otherwise it will use regular LoRA.
|
||||||
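As a minimal sketch (the output path `mlx_model_q` is just an example name), a
quantized model for QLoRA can be produced with the Python `convert` API and then
passed to `mlx_lm.lora` via `--model`:

```python
# Hedged example: quantize a base model locally so that fine-tuning with
# `mlx_lm.lora --model mlx_model_q ...` uses QLoRA.
from mlx_lm import convert

convert("mistralai/Mistral-7B-v0.1", mlx_path="mlx_model_q", quantize=True)
```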
|
|
||||||
|
By default, the adapter config and weights are saved in `adapters/`. You can
|
||||||
|
specify the output location with `--adapter-path`.
|
||||||
|
|
||||||
|
You can resume fine-tuning with an existing adapter with
|
||||||
|
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||||
|
|
||||||
|
### Evaluate
|
||||||
|
|
||||||
|
To compute test set perplexity use:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.lora \
|
||||||
|
--model <path_to_model> \
|
||||||
|
--adapter-path <path_to_adapters> \
|
||||||
|
--data <path_to_data> \
|
||||||
|
--test
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generate
|
||||||
|
|
||||||
|
For generation use `mlx_lm.generate`:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.generate \
|
||||||
|
--model <path_to_model> \
|
||||||
|
--adapter-path <path_to_adapters> \
|
||||||
|
--prompt "<your_model_prompt>"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fuse
|
||||||
|
|
||||||
|
You can generate a model fused with the low-rank adapters using the
|
||||||
|
`mlx_lm.fuse` command. This command also allows you to optionally:
|
||||||
|
|
||||||
|
- Upload the fused model to the Hugging Face Hub.
|
||||||
|
- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
|
||||||
|
Mixtral, and Llama style models in fp16 precision.
|
||||||
|
|
||||||
|
To see supported options run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.fuse --help
|
||||||
|
```
|
||||||
|
|
||||||
|
To generate the fused model run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.fuse --model <path_to_model>
|
||||||
|
```
|
||||||
|
|
||||||
|
This will by default load the adapters from `adapters/`, and save the fused
|
||||||
|
model in the path `lora_fused_model/`. All of these are configurable.
|
||||||
|
|
||||||
|
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
|
||||||
|
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
|
||||||
|
useful for the sake of attribution and model versioning.
|
||||||
|
|
||||||
|
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.fuse \
|
||||||
|
--model mistralai/Mistral-7B-v0.1 \
|
||||||
|
--upload-repo mlx-community/my-lora-mistral-7b \
|
||||||
|
--hf-path mistralai/Mistral-7B-v0.1
|
||||||
|
```
|
||||||
|
|
||||||
|
To export a fused model to GGUF, run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.fuse \
|
||||||
|
--model mistralai/Mistral-7B-v0.1 \
|
||||||
|
--export-gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
|
||||||
|
can specify the file name with `--gguf-path`.
|
||||||
|
|
||||||
|
## Data
|
||||||
|
|
||||||
|
The LoRA command expects you to provide a dataset with `--data`. The MLX
|
||||||
|
Examples GitHub repo has an [example of the WikiSQL
|
||||||
|
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
|
||||||
|
correct format.
|
||||||
|
|
||||||
|
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
||||||
|
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||||
|
loader expects a `test.jsonl` in the data directory.
|
||||||
|
|
||||||
|
Currently, `*.jsonl` files support three data formats: `chat`,
|
||||||
|
`completions`, and `text`. Here are three examples of these formats:
|
||||||
|
|
||||||
|
`chat`:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "How can I assistant you today."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`completions`:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{
|
||||||
|
"prompt": "What is the capital of France?",
|
||||||
|
"completion": "Paris."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`text`:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{
|
||||||
|
"text": "This is an example for the model."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that the format is determined automatically from the dataset. Keys in
each line that are not expected by the loader will be ignored.
|
||||||
|
|
||||||
|
For the `chat` and `completions` formats, Hugging Face [chat
|
||||||
|
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
||||||
|
the model's chat template by default. If the model does not have a chat
|
||||||
|
template, then Hugging Face will use a default. For example, the final text in
|
||||||
|
the `chat` example above with Hugging Face's default template becomes:
|
||||||
|
|
||||||
|
```text
|
||||||
|
<|im_start|>system
|
||||||
|
You are a helpful assistant.<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
Hello.<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
How can I assist you today.<|im_end|>
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are unsure which format to use, `chat` or `completions` are good places to
start. If you have custom requirements for the dataset format, use the
`text` format and assemble the content yourself.
|
||||||
|
|
||||||
|
## Memory Issues
|
||||||
|
|
||||||
|
Fine-tuning a large model with LoRA requires a machine with a decent amount
|
||||||
|
of memory. Here are some tips to reduce memory use should you need to do so:
|
||||||
|
|
||||||
|
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
|
||||||
|
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
|
||||||
|
more details.
|
||||||
|
|
||||||
|
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
|
||||||
|
setting this to `2` or `1` will reduce memory consumption. This may slow
|
||||||
|
things down a little, but will also reduce the memory use.
|
||||||
|
|
||||||
|
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
|
||||||
|
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
||||||
|
needed for back propagation. It may also reduce the quality of the
|
||||||
|
fine-tuned model if you are fine-tuning with a lot of data.
|
||||||
|
|
||||||
|
4. Longer examples require more memory. If it makes sense for your data, one thing
|
||||||
|
you can do is break your examples into smaller
|
||||||
|
sequences when making the `{train, valid, test}.jsonl` files.
|
||||||
|
|
||||||
|
5. Gradient checkpointing lets you trade-off memory use (less) for computation
|
||||||
|
(more) by recomputing instead of storing intermediate values needed by the
|
||||||
|
backward pass. You can use gradient checkpointing by passing the
|
||||||
|
`--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
|
||||||
|
larger batch sizes or sequence lengths with smaller or quantized models.
|
||||||
|
|
||||||
|
For example, for a machine with 32 GB the following should run reasonably fast:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.lora \
|
||||||
|
--model mistralai/Mistral-7B-v0.1 \
|
||||||
|
--train \
|
||||||
|
--batch-size 1 \
|
||||||
|
--lora-layers 4 \
|
||||||
|
--data wikisql
|
||||||
|
```
|
||||||
|
|
||||||
|
The above command on an M1 Max with 32 GB runs at about 250
|
||||||
|
tokens-per-second, using the MLX Example
|
||||||
|
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
|
||||||
|
data set.
|
||||||
|
|
||||||
|
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||||
|
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
||||||
llms/mlx_lm/MANAGE.md (new file, 22 lines)
@@ -0,0 +1,22 @@
|
|||||||
|
# Managing Models
|
||||||
|
|
||||||
|
You can use `mlx-lm` to manage models downloaded locally on your machine. They
|
||||||
|
are stored in the Hugging Face cache.
|
||||||
|
|
||||||
|
Scan models:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.manage --scan
|
||||||
|
```
|
||||||
|
|
||||||
|
Specify a `--pattern` to get info on a single or specific set of models:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
|
||||||
|
```
|
||||||
|
|
||||||
|
To delete a model (or multiple models):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
|
||||||
|
```
|
||||||
llms/mlx_lm/MERGE.md (new file, 50 lines)
@@ -0,0 +1,50 @@
|
|||||||
|
# Model Merging
|
||||||
|
|
||||||
|
You can use `mlx-lm` to merge models and upload them to the Hugging
|
||||||
|
Face hub or save them locally for LoRA fine tuning.
|
||||||
|
|
||||||
|
The main command is `mlx_lm.merge`:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.merge --config config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
The merged model will be saved by default in `mlx_merged_model`. To see a
|
||||||
|
full list of options run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.merge --help
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is an example `config.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
models:
|
||||||
|
- OpenPipe/mistral-ft-optimized-1218
|
||||||
|
- mlabonne/NeuralHermes-2.5-Mistral-7B
|
||||||
|
method: slerp
|
||||||
|
parameters:
|
||||||
|
t:
|
||||||
|
- filter: self_attn
|
||||||
|
value: [0, 0.5, 0.3, 0.7, 1]
|
||||||
|
- filter: mlp
|
||||||
|
value: [1, 0.5, 0.7, 0.3, 0]
|
||||||
|
- value: 0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
The `models` field is a list of Hugging Face repo ids. The first model in the
|
||||||
|
list is treated as the base model into which the remaining models are merged.
|
||||||
|
|
||||||
|
The `method` field is the merging method. Right now `slerp` is the only
|
||||||
|
supported method.
|
||||||
|
|
||||||
|
The `parameters` are the corresponding parameters for the given `method`.
|
||||||
|
Each parameter is a list with `filter` determining which layer the parameter
|
||||||
|
applies to and `value` determining the actual value used. The last item in
|
||||||
|
the list without a `filter` field is the default.
|
||||||
|
|
||||||
|
If `value` is a list, it specifies the start and end values for the
|
||||||
|
corresponding segment of blocks. In the example above, the models have 32
|
||||||
|
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
|
||||||
|
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
|
||||||
|
will use `np.linspace(0.5, 0.3, 8)`, and so on.
|
||||||
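A minimal sketch of that interpolation, assuming 32 blocks and the `t` values
from the example config above (an illustration of the scheme, not the exact
merge code):

```python
# Hedged sketch: expand the segment endpoints into one interpolation weight
# per transformer block, as described for the slerp `t` parameter above.
import numpy as np

value = [0, 0.5, 0.3, 0.7, 1]  # endpoints from the example config
num_blocks = 32
per_segment = num_blocks // (len(value) - 1)  # 8 blocks per segment

t = np.concatenate(
    [np.linspace(a, b, per_segment) for a, b in zip(value[:-1], value[1:])]
)
print(t.shape)  # (32,): one weight per block for the matching layers
```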
llms/mlx_lm/README.md (new file, 10 lines)
@@ -0,0 +1,10 @@
|
|||||||
|
## Generate Text with MLX and :hugs: Hugging Face
|
||||||
|
|
||||||
|
This is an example of large language model text generation that can pull models from
|
||||||
|
the Hugging Face Hub.
|
||||||
|
|
||||||
|
For more information on this example, see the [README](../README.md) in the
|
||||||
|
parent directory.
|
||||||
|
|
||||||
|
This package also supports fine tuning with LoRA or QLoRA. For more information
|
||||||
|
see the [LoRA documentation](LORA.md).
|
||||||
llms/mlx_lm/SERVER.md (new file, 76 lines)
@@ -0,0 +1,76 @@
|
|||||||
|
# HTTP Model Server
|
||||||
|
|
||||||
|
You can use `mlx-lm` to serve an HTTP API for generating text with any supported
|
||||||
|
model. The HTTP API is intended to be similar to the [OpenAI chat
|
||||||
|
API](https://platform.openai.com/docs/api-reference).
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> The MLX LM server is not recommended for production as it only implements
|
||||||
|
> basic security checks.
|
||||||
|
|
||||||
|
Start the server with:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.server --model <path_to_model_or_hf_repo>
|
||||||
|
```
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
|
||||||
|
```
|
||||||
|
|
||||||
|
This will start a text generation server on port `8080` of the `localhost`
|
||||||
|
using Mistral 7B instruct. The model will be downloaded from the provided
|
||||||
|
Hugging Face repo if it is not already in the local cache.
|
||||||
|
|
||||||
|
To see a full list of options run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mlx_lm.server --help
|
||||||
|
```
|
||||||
|
|
||||||
|
You can make a request to the model by running:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl localhost:8080/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||||
|
"temperature": 0.7
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Request Fields
|
||||||
|
|
||||||
|
- `messages`: An array of message objects representing the conversation
|
||||||
|
history. Each message object should have a role (e.g. user, assistant) and
|
||||||
|
content (the message text).
|
||||||
|
|
||||||
|
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
||||||
|
the generated prompt. If not provided, the default mappings are used.
|
||||||
|
|
||||||
|
- `stop`: (Optional) An array of strings or a single string. These are
|
||||||
|
sequences of tokens on which the generation should stop.
|
||||||
|
|
||||||
|
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
||||||
|
to generate. Defaults to `100`.
|
||||||
|
|
||||||
|
- `stream`: (Optional) A boolean indicating if the response should be
|
||||||
|
streamed. If true, responses are sent as they are generated. Defaults to
|
||||||
|
false.
|
||||||
|
|
||||||
|
- `temperature`: (Optional) A float specifying the sampling temperature.
|
||||||
|
Defaults to `1.0`.
|
||||||
|
|
||||||
|
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
|
||||||
|
Defaults to `1.0`.
|
||||||
|
|
||||||
|
- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
|
||||||
|
Defaults to `1.0`.
|
||||||
|
|
||||||
|
- `repetition_context_size`: (Optional) The size of the context window for
|
||||||
|
applying repetition penalty. Defaults to `20`.
|
||||||
|
|
||||||
|
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
|
||||||
|
values. Defaults to `None`.
|
||||||
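As a small sketch of the same request from Python (assuming the server above is
running on `localhost:8080`; only the standard library is used):

```python
# Hedged example: send a non-streaming chat completion request to the local server.
import json
import urllib.request

payload = {
    "messages": [{"role": "user", "content": "Say this is a test!"}],
    "max_tokens": 50,
    "temperature": 0.7,
}
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))
```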
llms/mlx_lm/UPLOAD.md (new file, 37 lines)
@@ -0,0 +1,37 @@
|
|||||||
|
### Packaging for PyPI
|
||||||
|
|
||||||
|
Install `build` and `twine`:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install --user --upgrade build
|
||||||
|
pip install --user --upgrade twine
|
||||||
|
```
|
||||||
|
|
||||||
|
Generate the source distribution and wheel:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m build
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!WARNING]
> Use the test PyPI server first to verify the package before uploading to PyPI.
|
||||||
|
|
||||||
|
#### Test Upload
|
||||||
|
|
||||||
|
Upload to test server:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m twine upload --repository testpypi dist/*
|
||||||
|
```
|
||||||
|
|
||||||
|
Install from test server and check that it works:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Upload
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m twine upload dist/*
|
||||||
|
```
|
||||||
llms/mlx_lm/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from .utils import convert, generate, load, stream_generate
|
||||||
|
from .version import __version__
|
||||||
llms/mlx_lm/convert.py (new file, 62 lines)
@@ -0,0 +1,62 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from .utils import convert
|
||||||
|
|
||||||
|
|
||||||
|
def configure_parser() -> argparse.ArgumentParser:
|
||||||
|
"""
|
||||||
|
Configures and returns the argument parser for the script.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
argparse.ArgumentParser: Configured argument parser.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert Hugging Face model to MLX format"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-group-size", help="Group size for quantization.", type=int, default=64
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
help="Type to save the parameters, ignored if -q is given.",
|
||||||
|
type=str,
|
||||||
|
choices=["float16", "bfloat16", "float32"],
|
||||||
|
default="float16",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upload-repo",
|
||||||
|
help="The Hugging Face repo to upload the model to.",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-d",
|
||||||
|
"--dequantize",
|
||||||
|
help="Dequantize a quantized model.",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = configure_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert(**vars(args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
llms/mlx_lm/examples/lora_config.yaml (new file, 71 lines)
@@ -0,0 +1,71 @@
|
|||||||
|
# The path to the local model directory or Hugging Face repo.
|
||||||
|
model: "mlx_model"
|
||||||
|
# Whether or not to train (boolean)
|
||||||
|
train: true
|
||||||
|
|
||||||
|
# Directory with {train, valid, test}.jsonl files
|
||||||
|
data: "/path/to/training/data"
|
||||||
|
|
||||||
|
# The PRNG seed
|
||||||
|
seed: 0
|
||||||
|
|
||||||
|
# Number of layers to fine-tune
|
||||||
|
lora_layers: 16
|
||||||
|
|
||||||
|
# Minibatch size.
|
||||||
|
batch_size: 4
|
||||||
|
|
||||||
|
# Iterations to train for.
|
||||||
|
iters: 1000
|
||||||
|
|
||||||
|
# Number of validation batches, -1 uses the entire validation set.
|
||||||
|
val_batches: 25
|
||||||
|
|
||||||
|
# Adam learning rate.
|
||||||
|
learning_rate: 1e-5
|
||||||
|
|
||||||
|
# Number of training steps between loss reporting.
|
||||||
|
steps_per_report: 10
|
||||||
|
|
||||||
|
# Number of training steps between validations.
|
||||||
|
steps_per_eval: 200
|
||||||
|
|
||||||
|
# Load path to resume training with the given adapter weights.
|
||||||
|
resume_adapter_file: null
|
||||||
|
|
||||||
|
# Save/load path for the trained adapter weights.
|
||||||
|
adapter_path: "adapters"
|
||||||
|
|
||||||
|
# Save the model every N iterations.
|
||||||
|
save_every: 100
|
||||||
|
|
||||||
|
# Evaluate on the test set after training
|
||||||
|
test: false
|
||||||
|
|
||||||
|
# Number of test set batches, -1 uses the entire test set.
|
||||||
|
test_batches: 100
|
||||||
|
|
||||||
|
# Maximum sequence length.
|
||||||
|
max_seq_length: 2048
|
||||||
|
|
||||||
|
# Use gradient checkpointing to reduce memory use.
|
||||||
|
grad_checkpoint: false
|
||||||
|
|
||||||
|
# Use DoRA instead of LoRA.
|
||||||
|
use_dora: false
|
||||||
|
|
||||||
|
# LoRA parameters can only be specified in a config file
|
||||||
|
lora_parameters:
|
||||||
|
# The layer keys to apply LoRA to.
|
||||||
|
# These will be applied for the last lora_layers
|
||||||
|
keys: ["self_attn.q_proj", "self_attn.v_proj"]
|
||||||
|
rank: 8
|
||||||
|
scale: 20.0
|
||||||
|
dropout: 0.0
|
||||||
|
|
||||||
|
# Schedule can only be specified in a config file, uncomment to use.
|
||||||
|
#lr_schedule:
|
||||||
|
# name: cosine_decay
|
||||||
|
# warmup: 100 # 0 for no warmup
|
||||||
|
# warmup_init: 1e-7 # 0 if not specified
|
||||||
|
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
|
||||||
llms/mlx_lm/examples/merge_config.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
|
|||||||
|
models:
|
||||||
|
- OpenPipe/mistral-ft-optimized-1218
|
||||||
|
- mlabonne/NeuralHermes-2.5-Mistral-7B
|
||||||
|
method: slerp
|
||||||
|
parameters:
|
||||||
|
t:
|
||||||
|
- filter: self_attn
|
||||||
|
value: [0, 0.5, 0.3, 0.7, 1]
|
||||||
|
- filter: mlp
|
||||||
|
value: [1, 0.5, 0.7, 0.3, 0]
|
||||||
|
- value: 0.5
|
||||||
llms/mlx_lm/fuse.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
|||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
|
||||||
|
from .gguf import convert_to_gguf
|
||||||
|
from .tuner.dora import DoRALinear
|
||||||
|
from .tuner.lora import LoRALinear, LoRASwitchLinear
|
||||||
|
from .tuner.utils import apply_lora_layers, dequantize
|
||||||
|
from .utils import (
    fetch_from_hub,
    get_model_path,
    save_config,
    save_weights,
    upload_to_hub,
)


def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Fuse fine-tuned adapters into the base model."
    )
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="lora_fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        default="adapters",
        help="Path to the trained adapter weights and config.",
    )
    parser.add_argument(
        "--hf-path",
        type=str,
        default=None,
        help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--de-quantize",
        help="Generate a de-quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--export-gguf",
        help="Export model weights in GGUF format.",
        action="store_true",
    )
    parser.add_argument(
        "--gguf-path",
        help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
        default="ggml-model-f16.gguf",
        type=str,
    )
    return parser.parse_args()


def main() -> None:
    print("Loading pretrained model")
    args = parse_arguments()

    model_path = get_model_path(args.model)
    model, config, tokenizer = fetch_from_hub(model_path)

    model.freeze()
    model = apply_lora_layers(model, args.adapter_path)

    fused_linears = [
        (n, m.to_linear())
        for n, m in model.named_modules()
        if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
    ]

    model.update_modules(tree_unflatten(fused_linears))

    if args.de_quantize:
        print("De-quantizing model")
        model = dequantize(model)

    weights = dict(tree_flatten(model.parameters()))

    save_path = Path(args.save_path)

    save_weights(save_path, weights)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, save_path)

    tokenizer.save_pretrained(save_path)

    if args.de_quantize:
        config.pop("quantization", None)

    save_config(config, config_path=save_path / "config.json")

    if args.export_gguf:
        model_type = config["model_type"]
        if model_type not in ["llama", "mixtral", "mistral"]:
            raise ValueError(
                f"Model type {model_type} not supported for GGUF conversion."
            )
        convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))

    if args.upload_repo is not None:
        hf_path = args.hf_path or (
            args.model if not Path(args.model).exists() else None
        )
        if hf_path is None:
            raise ValueError(
                "Must provide original Hugging Face repo to upload local model."
            )
        upload_to_hub(args.save_path, args.upload_repo, hf_path)


if __name__ == "__main__":
    main()
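For reference (not part of the diff), a typical invocation of the fuse script above; the module path `mlx_lm.fuse` and all argument values are illustrative assumptions:

python -m mlx_lm.fuse --model mlx_model --adapter-path adapters --save-path lora_fused_model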
161
llms/mlx_lm/generate.py
Normal file
@@ -0,0 +1,161 @@
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .utils import generate, load
|
||||||
|
|
||||||
|
DEFAULT_MODEL_PATH = "mlx_model"
|
||||||
|
DEFAULT_PROMPT = "hello"
|
||||||
|
DEFAULT_MAX_TOKENS = 100
|
||||||
|
DEFAULT_TEMP = 0.6
|
||||||
|
DEFAULT_TOP_P = 1.0
|
||||||
|
DEFAULT_SEED = 0
|
||||||
|
|
||||||
|
|
||||||
|
def setup_arg_parser():
|
||||||
|
"""Set up and return the argument parser."""
|
||||||
|
parser = argparse.ArgumentParser(description="LLM inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="mlx_model",
|
||||||
|
help="The path to the local model directory or Hugging Face repo.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adapter-path",
|
||||||
|
type=str,
|
||||||
|
help="Optional path for the trained adapter weights and config.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--trust-remote-code",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable trusting remote code for tokenizer",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eos-token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="End of sequence token for tokenizer",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_MAX_TOKENS,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ignore-chat-template",
|
||||||
|
action="store_true",
|
||||||
|
help="Use the raw prompt without the tokenizer's chat template.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-default-chat-template",
|
||||||
|
action="store_true",
|
||||||
|
help="Use the default chat template",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--colorize",
|
||||||
|
action="store_true",
|
||||||
|
help="Colorize output based on T[0] probability",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache-limit-gb",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Set the MLX cache limit in GB",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def colorprint(color, s):
|
||||||
|
color_codes = {
|
||||||
|
"black": 30,
|
||||||
|
"red": 31,
|
||||||
|
"green": 32,
|
||||||
|
"yellow": 33,
|
||||||
|
"blue": 34,
|
||||||
|
"magenta": 35,
|
||||||
|
"cyan": 36,
|
||||||
|
"white": 39,
|
||||||
|
}
|
||||||
|
ccode = color_codes.get(color, 30)
|
||||||
|
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def colorprint_by_t0(s, t0):
|
||||||
|
if t0 > 0.95:
|
||||||
|
color = "white"
|
||||||
|
elif t0 > 0.70:
|
||||||
|
color = "green"
|
||||||
|
elif t0 > 0.30:
|
||||||
|
color = "yellow"
|
||||||
|
else:
|
||||||
|
color = "red"
|
||||||
|
colorprint(color, s)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = setup_arg_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
if args.cache_limit_gb is not None:
|
||||||
|
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||||
|
|
||||||
|
# Building tokenizer_config
|
||||||
|
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
|
||||||
|
if args.eos_token is not None:
|
||||||
|
tokenizer_config["eos_token"] = args.eos_token
|
||||||
|
|
||||||
|
model, tokenizer = load(
|
||||||
|
args.model,
|
||||||
|
adapter_path=args.adapter_path,
|
||||||
|
tokenizer_config=tokenizer_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_default_chat_template:
|
||||||
|
if tokenizer.chat_template is None:
|
||||||
|
tokenizer.chat_template = tokenizer.default_chat_template
|
||||||
|
|
||||||
|
if not args.ignore_chat_template and (
|
||||||
|
hasattr(tokenizer, "apply_chat_template")
|
||||||
|
and tokenizer.chat_template is not None
|
||||||
|
):
|
||||||
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = args.prompt
|
||||||
|
|
||||||
|
formatter = colorprint_by_t0 if args.colorize else None
|
||||||
|
|
||||||
|
generate(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
prompt,
|
||||||
|
args.max_tokens,
|
||||||
|
verbose=True,
|
||||||
|
formatter=formatter,
|
||||||
|
temp=args.temp,
|
||||||
|
top_p=args.top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
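A minimal sketch, not part of the diff, of driving the same pipeline programmatically; it mirrors the load(...) and generate(...) calls in main() above, and the model path is a placeholder:

# Sketch only: mirrors generate.py's main(); "mlx_model" is a placeholder path.
from mlx_lm.utils import generate, load

model, tokenizer = load("mlx_model")
generate(model, tokenizer, "hello", 100, verbose=True)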
313
llms/mlx_lm/gguf.py
Normal file
@@ -0,0 +1,313 @@
import re
|
||||||
|
from enum import IntEnum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TokenType(IntEnum):
|
||||||
|
NORMAL = 1
|
||||||
|
UNKNOWN = 2
|
||||||
|
CONTROL = 3
|
||||||
|
USER_DEFINED = 4
|
||||||
|
UNUSED = 5
|
||||||
|
BYTE = 6
|
||||||
|
|
||||||
|
|
||||||
|
class GGMLFileType(IntEnum):
|
||||||
|
GGML_TYPE_F16 = 1
|
||||||
|
|
||||||
|
|
||||||
|
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
|
||||||
|
class HfVocab:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fname_tokenizer: Path,
|
||||||
|
fname_added_tokens: Optional[Path] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
fname_tokenizer,
|
||||||
|
cache_dir=fname_tokenizer,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
self.added_tokens_list = []
|
||||||
|
self.added_tokens_dict = dict()
|
||||||
|
self.added_tokens_ids = set()
|
||||||
|
for tok, tokidx in sorted(
|
||||||
|
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
|
||||||
|
):
|
||||||
|
if tokidx >= self.tokenizer.vocab_size:
|
||||||
|
self.added_tokens_list.append(tok)
|
||||||
|
self.added_tokens_dict[tok] = tokidx
|
||||||
|
self.added_tokens_ids.add(tokidx)
|
||||||
|
self.specials = {
|
||||||
|
tok: self.tokenizer.get_vocab()[tok]
|
||||||
|
for tok in self.tokenizer.all_special_tokens
|
||||||
|
}
|
||||||
|
self.special_ids = set(self.tokenizer.all_special_ids)
|
||||||
|
self.vocab_size_base = self.tokenizer.vocab_size
|
||||||
|
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||||
|
self.fname_tokenizer = fname_tokenizer
|
||||||
|
self.fname_added_tokens = fname_added_tokens
|
||||||
|
|
||||||
|
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||||
|
reverse_vocab = {
|
||||||
|
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||||
|
}
|
||||||
|
for token_id in range(self.vocab_size_base):
|
||||||
|
if token_id in self.added_tokens_ids:
|
||||||
|
continue
|
||||||
|
token_text = reverse_vocab[token_id].encode("utf-8")
|
||||||
|
yield token_text, self.get_token_score(token_id), self.get_token_type(
|
||||||
|
token_id, token_text, self.special_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_token_type(
|
||||||
|
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||||
|
) -> TokenType:
|
||||||
|
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||||
|
return TokenType.BYTE
|
||||||
|
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
|
||||||
|
|
||||||
|
def get_token_score(self, token_id: int) -> float:
|
||||||
|
return -1000.0
|
||||||
|
|
||||||
|
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||||
|
for text in self.added_tokens_list:
|
||||||
|
if text in self.specials:
|
||||||
|
toktype = self.get_token_type(
|
||||||
|
self.specials[text], b"", self.special_ids
|
||||||
|
)
|
||||||
|
score = self.get_token_score(self.specials[text])
|
||||||
|
else:
|
||||||
|
toktype = TokenType.USER_DEFINED
|
||||||
|
score = -1000.0
|
||||||
|
yield text.encode("utf-8"), score, toktype
|
||||||
|
|
||||||
|
def has_newline_token(self):
|
||||||
|
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||||
|
|
||||||
|
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||||
|
yield from self.hf_tokens()
|
||||||
|
yield from self.added_tokens()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(path: Path) -> "HfVocab":
|
||||||
|
added_tokens_path = path.parent / "added_tokens.json"
|
||||||
|
return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
|
||||||
|
|
||||||
|
|
||||||
|
def translate_weight_names(name):
|
||||||
|
name = name.replace("model.layers.", "blk.")
|
||||||
|
# for mixtral gate
|
||||||
|
name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
|
||||||
|
# for mixtral experts ffns
|
||||||
|
pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
|
||||||
|
replacement = r"ffn_gate.\1.weight"
|
||||||
|
name = re.sub(pattern, replacement, name)
|
||||||
|
pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
|
||||||
|
replacement = r"ffn_down.\1.weight"
|
||||||
|
name = re.sub(pattern, replacement, name)
|
||||||
|
pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
|
||||||
|
replacement = r"ffn_up.\1.weight"
|
||||||
|
name = re.sub(pattern, replacement, name)
|
||||||
|
|
||||||
|
name = name.replace("mlp.gate_proj", "ffn_gate")
|
||||||
|
name = name.replace("mlp.down_proj", "ffn_down")
|
||||||
|
name = name.replace("mlp.up_proj", "ffn_up")
|
||||||
|
name = name.replace("self_attn.q_proj", "attn_q")
|
||||||
|
name = name.replace("self_attn.k_proj", "attn_k")
|
||||||
|
name = name.replace("self_attn.v_proj", "attn_v")
|
||||||
|
name = name.replace("self_attn.o_proj", "attn_output")
|
||||||
|
name = name.replace("input_layernorm", "attn_norm")
|
||||||
|
name = name.replace("post_attention_layernorm", "ffn_norm")
|
||||||
|
name = name.replace("model.embed_tokens", "token_embd")
|
||||||
|
name = name.replace("model.norm", "output_norm")
|
||||||
|
name = name.replace("lm_head", "output")
|
||||||
|
return name
|
||||||
|
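A few illustrative mappings produced by translate_weight_names above (not part of the diff), handy when checking the GGUF key layout:

# Illustration only; each assertion follows directly from the replace/re.sub rules above.
assert translate_weight_names("model.layers.0.self_attn.q_proj.weight") == "blk.0.attn_q.weight"
assert (
    translate_weight_names("model.layers.3.block_sparse_moe.experts.1.w2.weight")
    == "blk.3.ffn_down.1.weight"
)
assert translate_weight_names("lm_head.weight") == "output.weight"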
|
||||||
|
|
||||||
|
def permute_weights(weights, n_head, n_head_kv=None):
|
||||||
|
if n_head_kv is not None and n_head != n_head_kv:
|
||||||
|
n_head = n_head_kv
|
||||||
|
reshaped = weights.reshape(
|
||||||
|
n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
|
||||||
|
)
|
||||||
|
swapped = reshaped.swapaxes(1, 2)
|
||||||
|
final_shape = weights.shape
|
||||||
|
return swapped.reshape(final_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_metadata(config, vocab):
|
||||||
|
metadata = {
|
||||||
|
"general.name": "llama",
|
||||||
|
"llama.context_length": (
|
||||||
|
mx.array(config["max_position_embeddings"], dtype=mx.uint32)
|
||||||
|
if config.get("max_position_embeddings") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.embedding_length": (
|
||||||
|
mx.array(config["hidden_size"], dtype=mx.uint32)
|
||||||
|
if config.get("hidden_size") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.block_count": (
|
||||||
|
mx.array(config["num_hidden_layers"], dtype=mx.uint32)
|
||||||
|
if config.get("num_hidden_layers") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.feed_forward_length": (
|
||||||
|
mx.array(config["intermediate_size"], dtype=mx.uint32)
|
||||||
|
if config.get("intermediate_size") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.rope.dimension_count": (
|
||||||
|
mx.array(
|
||||||
|
config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
|
||||||
|
)
|
||||||
|
if config.get("hidden_size") is not None
|
||||||
|
and config.get("num_attention_heads") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.attention.head_count": (
|
||||||
|
mx.array(config["num_attention_heads"], dtype=mx.uint32)
|
||||||
|
if config.get("num_attention_heads") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.attention.head_count_kv": (
|
||||||
|
mx.array(
|
||||||
|
config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||||
|
dtype=mx.uint32,
|
||||||
|
)
|
||||||
|
if config.get("num_attention_heads") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.expert_count": (
|
||||||
|
mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
|
||||||
|
if config.get("num_local_experts") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.expert_used_count": (
|
||||||
|
mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
|
||||||
|
if config.get("num_experts_per_tok") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.attention.layer_norm_rms_epsilon": (
|
||||||
|
mx.array(config.get("rms_norm_eps", 1e-05))
|
||||||
|
if config.get("rms_norm_eps") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"llama.rope.freq_base": (
|
||||||
|
mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
|
||||||
|
if config.get("rope_theta") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
rope_scaling = config.get("rope_scaling")
|
||||||
|
if rope_scaling is not None and (typ := rope_scaling.get("type")):
|
||||||
|
rope_factor = rope_scaling.get("factor")
|
||||||
|
f_rope_scale = rope_factor
|
||||||
|
if typ == "linear":
|
||||||
|
rope_scaling_type = "linear"
|
||||||
|
metadata["llama.rope.scaling.type"] = rope_scaling_type
|
||||||
|
metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)
|
||||||
|
|
||||||
|
metadata["general.file_type"] = mx.array(
|
||||||
|
GGMLFileType.GGML_TYPE_F16.value,
|
||||||
|
dtype=mx.uint32,
|
||||||
|
)
|
||||||
|
metadata["general.quantization_version"] = mx.array(
|
||||||
|
GGMLFileType.GGML_TYPE_F16.value,
|
||||||
|
dtype=mx.uint32,
|
||||||
|
)
|
||||||
|
metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
|
||||||
|
metadata["general.architecture"] = "llama"
|
||||||
|
metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)
|
||||||
|
|
||||||
|
# add metadata for vocab
|
||||||
|
metadata["tokenizer.ggml.model"] = "llama"
|
||||||
|
tokens = []
|
||||||
|
scores = []
|
||||||
|
toktypes = []
|
||||||
|
for text, score, toktype in vocab.all_tokens():
|
||||||
|
tokens.append(text)
|
||||||
|
scores.append(score)
|
||||||
|
toktypes.append(toktype.value)
|
||||||
|
assert len(tokens) == vocab.vocab_size
|
||||||
|
metadata["tokenizer.ggml.tokens"] = tokens
|
||||||
|
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
|
||||||
|
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
|
||||||
|
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
|
||||||
|
vocab.tokenizer.bos_token_id, dtype=mx.uint32
|
||||||
|
)
|
||||||
|
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
|
||||||
|
vocab.tokenizer.eos_token_id, dtype=mx.uint32
|
||||||
|
)
|
||||||
|
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
|
||||||
|
vocab.tokenizer.unk_token_id, dtype=mx.uint32
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_gguf(
|
||||||
|
model_path: Union[str, Path],
|
||||||
|
weights: dict,
|
||||||
|
config: dict,
|
||||||
|
output_file_path: str,
|
||||||
|
):
|
||||||
|
if isinstance(model_path, str):
|
||||||
|
model_path = Path(model_path)
|
||||||
|
|
||||||
|
quantization = config.get("quantization", None)
|
||||||
|
if quantization:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Conversion of quantized models is not yet supported."
|
||||||
|
)
|
||||||
|
print("Converting to GGUF format")
|
||||||
|
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention
|
||||||
|
weights = {
|
||||||
|
k: (
|
||||||
|
permute_weights(
|
||||||
|
v, config["num_attention_heads"], config["num_attention_heads"]
|
||||||
|
)
|
||||||
|
if "self_attn.q_proj.weight" in k
|
||||||
|
else (
|
||||||
|
permute_weights(
|
||||||
|
v, config["num_attention_heads"], config["num_key_value_heads"]
|
||||||
|
)
|
||||||
|
if "self_attn.k_proj.weight" in k
|
||||||
|
else v
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for k, v in weights.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# rename weights for gguf format
|
||||||
|
weights = {translate_weight_names(k): v for k, v in weights.items()}
|
||||||
|
|
||||||
|
if not (model_path / "tokenizer.json").exists():
|
||||||
|
raise ValueError("Tokenizer json not found")
|
||||||
|
|
||||||
|
vocab = HfVocab.load(model_path)
|
||||||
|
metadata = prepare_metadata(config, vocab)
|
||||||
|
|
||||||
|
weights = {
|
||||||
|
k: (
|
||||||
|
v.astype(mx.float32).astype(mx.float16)
|
||||||
|
if v.dtype == mx.bfloat16
|
||||||
|
else v.astype(mx.float32) if "norm" in k else v
|
||||||
|
)
|
||||||
|
for k, v in weights.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
mx.save_gguf(output_file_path, weights, metadata)
|
||||||
|
print(f"Converted GGUF model saved as: {output_file_path}")
|
||||||
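A minimal sketch, not part of the diff, of calling convert_to_gguf directly; it mirrors how fuse.py above assembles the weights and config, and the model path is a placeholder (the model must be unquantized and ship a tokenizer.json):

# Sketch only: placeholder path; weights/config come from fetch_from_hub as in fuse.py.
from mlx.utils import tree_flatten
from mlx_lm.gguf import convert_to_gguf
from mlx_lm.utils import fetch_from_hub, get_model_path

model_path = get_model_path("path/to/llama-style-model")  # hypothetical
model, config, tokenizer = fetch_from_hub(model_path)
weights = dict(tree_flatten(model.parameters()))
convert_to_gguf(model_path, weights, config, "ggml-model-f16.gguf")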
278
llms/mlx_lm/lora.py
Normal file
@@ -0,0 +1,278 @@
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from .tokenizer_utils import TokenizerWrapper
|
||||||
|
from .tuner.datasets import load_dataset
|
||||||
|
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||||
|
from .tuner.utils import (
|
||||||
|
apply_lora_layers,
|
||||||
|
build_schedule,
|
||||||
|
linear_to_lora_layers,
|
||||||
|
print_trainable_parameters,
|
||||||
|
)
|
||||||
|
from .utils import load, save_config
|
||||||
|
|
||||||
|
yaml_loader = yaml.SafeLoader
|
||||||
|
yaml_loader.add_implicit_resolver(
|
||||||
|
"tag:yaml.org,2002:float",
|
||||||
|
re.compile(
|
||||||
|
"""^(?:
|
||||||
|
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||||
|
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||||
|
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||||
|
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||||
|
|[-+]?\\.(?:inf|Inf|INF)
|
||||||
|
|\\.(?:nan|NaN|NAN))$""",
|
||||||
|
re.X,
|
||||||
|
),
|
||||||
|
list("-+0123456789."),
|
||||||
|
)
|
||||||
|
|
||||||
|
CONFIG_DEFAULTS = {
|
||||||
|
"model": "mlx_model",
|
||||||
|
"train": False,
|
||||||
|
"data": "data/",
|
||||||
|
"seed": 0,
|
||||||
|
"lora_layers": 16,
|
||||||
|
"batch_size": 4,
|
||||||
|
"iters": 1000,
|
||||||
|
"val_batches": 25,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"steps_per_report": 10,
|
||||||
|
"steps_per_eval": 200,
|
||||||
|
"resume_adapter_file": None,
|
||||||
|
"adapter_path": "adapters",
|
||||||
|
"save_every": 100,
|
||||||
|
"test": False,
|
||||||
|
"test_batches": 500,
|
||||||
|
"max_seq_length": 2048,
|
||||||
|
"lr_schedule": None,
|
||||||
|
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||||
|
"use_dora": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
help="The path to the local model directory or Hugging Face repo.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training args
|
||||||
|
parser.add_argument(
|
||||||
|
"--train",
|
||||||
|
action="store_true",
|
||||||
|
help="Do training",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data",
|
||||||
|
type=str,
|
||||||
|
help="Directory with {train, valid, test}.jsonl files",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-layers",
|
||||||
|
type=int,
|
||||||
|
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
|
||||||
|
parser.add_argument("--iters", type=int, help="Iterations to train for.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--val-batches",
|
||||||
|
type=int,
|
||||||
|
help="Number of validation batches, -1 uses the entire validation set.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps-per-report",
|
||||||
|
type=int,
|
||||||
|
help="Number of training steps between loss reporting.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps-per-eval",
|
||||||
|
type=int,
|
||||||
|
help="Number of training steps between validations.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume-adapter-file",
|
||||||
|
type=str,
|
||||||
|
help="Load path to resume training with the given adapters.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adapter-path",
|
||||||
|
type=str,
|
||||||
|
help="Save/load path for the adapters.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-every",
|
||||||
|
type=int,
|
||||||
|
help="Save the model every N iterations.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test",
|
||||||
|
action="store_true",
|
||||||
|
help="Evaluate on the test set after training",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-batches",
|
||||||
|
type=int,
|
||||||
|
help="Number of test set batches, -1 uses the entire test set.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-seq-length",
|
||||||
|
type=int,
|
||||||
|
help="Maximum sequence length.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-c",
|
||||||
|
"--config",
|
||||||
|
default=None,
|
||||||
|
help="A YAML configuration file with the training options",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad-checkpoint",
|
||||||
|
action="store_true",
|
||||||
|
help="Use gradient checkpointing to reduce memory use.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(
|
||||||
|
args,
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: TokenizerWrapper,
|
||||||
|
train_set,
|
||||||
|
valid_set,
|
||||||
|
training_callback: TrainingCallback = None,
|
||||||
|
):
|
||||||
|
# Freeze all layers
|
||||||
|
model.freeze()
|
||||||
|
|
||||||
|
# Convert linear layers to lora layers and unfreeze in the process
|
||||||
|
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
|
||||||
|
|
||||||
|
# Resume training the given adapters.
|
||||||
|
if args.resume_adapter_file is not None:
|
||||||
|
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||||
|
model.load_weights(args.resume_adapter_file, strict=False)
|
||||||
|
|
||||||
|
print_trainable_parameters(model)
|
||||||
|
|
||||||
|
adapter_path = Path(args.adapter_path)
|
||||||
|
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
adapter_file = adapter_path / "adapters.safetensors"
|
||||||
|
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||||
|
|
||||||
|
# init training args
|
||||||
|
training_args = TrainingArgs(
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
iters=args.iters,
|
||||||
|
val_batches=args.val_batches,
|
||||||
|
steps_per_report=args.steps_per_report,
|
||||||
|
steps_per_eval=args.steps_per_eval,
|
||||||
|
steps_per_save=args.save_every,
|
||||||
|
adapter_file=adapter_file,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
grad_checkpoint=args.grad_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
opt = optim.Adam(
|
||||||
|
learning_rate=(
|
||||||
|
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Train model
|
||||||
|
train(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
args=training_args,
|
||||||
|
optimizer=opt,
|
||||||
|
train_dataset=train_set,
|
||||||
|
val_dataset=valid_set,
|
||||||
|
training_callback=training_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
test_loss = evaluate(
|
||||||
|
model=model,
|
||||||
|
dataset=test_set,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
num_batches=args.test_batches,
|
||||||
|
max_seq_length=args.max_seq_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_ppl = math.exp(test_loss)
|
||||||
|
|
||||||
|
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||||
|
|
||||||
|
|
||||||
|
def run(args, training_callback: TrainingCallback = None):
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
|
print("Loading pretrained model")
|
||||||
|
model, tokenizer = load(args.model)
|
||||||
|
|
||||||
|
print("Loading datasets")
|
||||||
|
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
||||||
|
|
||||||
|
if args.test and not args.train:
|
||||||
|
# Allow testing without LoRA layers by providing empty path
|
||||||
|
if args.adapter_path != "":
|
||||||
|
apply_lora_layers(model, args.adapter_path)
|
||||||
|
|
||||||
|
elif args.train:
|
||||||
|
print("Training")
|
||||||
|
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
|
||||||
|
else:
|
||||||
|
raise ValueError("Must provide at least one of --train or --test")
|
||||||
|
|
||||||
|
if args.test:
|
||||||
|
print("Testing")
|
||||||
|
evaluate_model(args, model, tokenizer, test_set)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = build_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
config = args.config
|
||||||
|
args = vars(args)
|
||||||
|
if config:
|
||||||
|
print("Loading configuration file", config)
|
||||||
|
with open(config, "r") as file:
|
||||||
|
config = yaml.load(file, yaml_loader)
|
||||||
|
# Prefer parameters from command-line arguments
|
||||||
|
for k, v in config.items():
|
||||||
|
if args.get(k, None) is None:
|
||||||
|
args[k] = v
|
||||||
|
|
||||||
|
# Update defaults for unspecified parameters
|
||||||
|
for k, v in CONFIG_DEFAULTS.items():
|
||||||
|
if args.get(k, None) is None:
|
||||||
|
args[k] = v
|
||||||
|
run(types.SimpleNamespace(**args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
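A minimal sketch, not part of the diff, of driving run() without the CLI; it reuses CONFIG_DEFAULTS from above, and note that grad_checkpoint is not in CONFIG_DEFAULTS (the argument parser normally supplies it), so it is set explicitly here. The model path is a placeholder:

# Sketch only: placeholder model path; grad_checkpoint added because it is filled
# in by the argument parser rather than CONFIG_DEFAULTS.
import types
from mlx_lm.lora import CONFIG_DEFAULTS, run

args = dict(CONFIG_DEFAULTS)
args.update(model="path/to/model", train=True, grad_checkpoint=False)
run(types.SimpleNamespace(**args))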
121
llms/mlx_lm/manage.py
Normal file
@@ -0,0 +1,121 @@
import argparse
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from huggingface_hub import scan_cache_dir
|
||||||
|
from transformers.commands.user import tabulate
|
||||||
|
|
||||||
|
|
||||||
|
def ask_for_confirmation(message: str) -> bool:
|
||||||
|
y = ("y", "yes", "1")
|
||||||
|
n = ("n", "no", "0")
|
||||||
|
all_values = y + n + ("",)
|
||||||
|
full_message = f"{message} (Y/n) "
|
||||||
|
while True:
|
||||||
|
answer = input(full_message).lower()
|
||||||
|
if answer == "":
|
||||||
|
return False
|
||||||
|
if answer in y:
|
||||||
|
return True
|
||||||
|
if answer in n:
|
||||||
|
return False
|
||||||
|
print(f"Invalid input. Must be one of {all_values}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="MLX Model Cache.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--scan",
|
||||||
|
action="store_true",
|
||||||
|
help="Scan Hugging Face cache for mlx models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--delete",
|
||||||
|
action="store_true",
|
||||||
|
help="Delete models matching the given pattern.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pattern",
|
||||||
|
type=str,
|
||||||
|
help="Model repos contain the pattern.",
|
||||||
|
default="mlx",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.scan:
|
||||||
|
print(
|
||||||
|
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
|
||||||
|
)
|
||||||
|
hf_cache_info = scan_cache_dir()
|
||||||
|
print(
|
||||||
|
tabulate(
|
||||||
|
rows=[
|
||||||
|
[
|
||||||
|
repo.repo_id,
|
||||||
|
repo.repo_type,
|
||||||
|
"{:>12}".format(repo.size_on_disk_str),
|
||||||
|
repo.nb_files,
|
||||||
|
repo.last_accessed_str,
|
||||||
|
repo.last_modified_str,
|
||||||
|
str(repo.repo_path),
|
||||||
|
]
|
||||||
|
for repo in sorted(
|
||||||
|
hf_cache_info.repos, key=lambda repo: repo.repo_path
|
||||||
|
)
|
||||||
|
if args.pattern in repo.repo_id
|
||||||
|
],
|
||||||
|
headers=[
|
||||||
|
"REPO ID",
|
||||||
|
"REPO TYPE",
|
||||||
|
"SIZE ON DISK",
|
||||||
|
"NB FILES",
|
||||||
|
"LAST_ACCESSED",
|
||||||
|
"LAST_MODIFIED",
|
||||||
|
"LOCAL PATH",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.delete:
|
||||||
|
print(f'Deleting models matching pattern "{args.pattern}"')
|
||||||
|
hf_cache_info = scan_cache_dir()
|
||||||
|
|
||||||
|
repos = [
|
||||||
|
repo
|
||||||
|
for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
|
||||||
|
if args.pattern in repo.repo_id
|
||||||
|
]
|
||||||
|
if repos:
|
||||||
|
print(
|
||||||
|
tabulate(
|
||||||
|
rows=[
|
||||||
|
[
|
||||||
|
repo.repo_id,
|
||||||
|
str(repo.repo_path),
|
||||||
|
]
|
||||||
|
for repo in repos
|
||||||
|
],
|
||||||
|
headers=[
|
||||||
|
"REPO ID",
|
||||||
|
"LOCAL PATH",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
confirmed = ask_for_confirmation(f"Confirm deletion ?")
|
||||||
|
if confirmed:
|
||||||
|
for model_info in repos:
|
||||||
|
for revision in sorted(
|
||||||
|
model_info.revisions, key=lambda revision: revision.commit_hash
|
||||||
|
):
|
||||||
|
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
|
||||||
|
strategy.execute()
|
||||||
|
print("Model(s) deleted.")
|
||||||
|
else:
|
||||||
|
print("Deletion is cancelled. Do nothing.")
|
||||||
|
else:
|
||||||
|
print(f"No models found.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
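For reference (not part of the diff), the cache manager above is driven entirely from the command line; a typical invocation, assuming the mlx_lm package layout implied by this diff, is:

python -m mlx_lm.manage --scan --pattern mlx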
172
llms/mlx_lm/merge.py
Normal file
@@ -0,0 +1,172 @@
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
from mlx.utils import tree_flatten, tree_map
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
fetch_from_hub,
|
||||||
|
get_model_path,
|
||||||
|
save_config,
|
||||||
|
save_weights,
|
||||||
|
upload_to_hub,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_parser() -> argparse.ArgumentParser:
|
||||||
|
"""
|
||||||
|
Configures and returns the argument parser for the script.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
argparse.ArgumentParser: Configured argument parser.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="Merge multiple models.")
|
||||||
|
|
||||||
|
parser.add_argument("--config", type=str, help="Path to the YAML config.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mlx-path",
|
||||||
|
type=str,
|
||||||
|
default="mlx_merged_model",
|
||||||
|
help="Path to save the MLX model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upload-repo",
|
||||||
|
help="The Hugging Face repo to upload the model to.",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def slerp(t, w1, w2, eps=1e-5):
|
||||||
|
"""
|
||||||
|
Spherical linear interpolation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (float): Interpolation weight in [0.0, 1.0]
|
||||||
|
w1 (mx.array): First input
|
||||||
|
w2 (mx.array): Second input
|
||||||
|
eps (float): Constant for numerical stability
|
||||||
|
Returns:
|
||||||
|
mx.array: Interpolated result
|
||||||
|
"""
|
||||||
|
t = float(t)
|
||||||
|
if t == 0:
|
||||||
|
return w1
|
||||||
|
elif t == 1:
|
||||||
|
return w2
|
||||||
|
# Normalize
|
||||||
|
v1 = w1 / mx.linalg.norm(w1)
|
||||||
|
v2 = w2 / mx.linalg.norm(w2)
|
||||||
|
# Angle
|
||||||
|
dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
|
||||||
|
theta = mx.arccos(dot)
|
||||||
|
sin_theta = mx.sin(theta + eps)
|
||||||
|
s1 = mx.sin(theta * (1 - t)) / sin_theta
|
||||||
|
s2 = mx.sin(theta * t) / sin_theta
|
||||||
|
return s1 * w1 + s2 * w2
|
||||||
|
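A small sanity check, not part of the diff, of the slerp above: it returns the endpoints at t=0 and t=1 and blends directions in between (values shown are for orthogonal unit vectors):

# Illustration only.
import mlx.core as mx

w1 = mx.array([1.0, 0.0])
w2 = mx.array([0.0, 1.0])
print(slerp(0.0, w1, w2))  # [1, 0]
print(slerp(1.0, w1, w2))  # [0, 1]
print(slerp(0.5, w1, w2))  # approximately [0.707, 0.707]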
|
||||||
|
|
||||||
|
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
|
||||||
|
method = config.get("method", None)
|
||||||
|
if method != "slerp":
|
||||||
|
raise ValueError(f"Merge method {method} not supported")
|
||||||
|
|
||||||
|
num_layers = len(model.layers)
|
||||||
|
|
||||||
|
def unpack_values(vals):
|
||||||
|
if isinstance(vals, (int, float)):
|
||||||
|
return np.full(num_layers, vals)
|
||||||
|
bins = len(vals) - 1
|
||||||
|
sizes = [num_layers // bins] * bins
|
||||||
|
sizes[-1] = num_layers - sum(sizes[:-1])
|
||||||
|
return np.concatenate(
|
||||||
|
[np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
|
||||||
|
)
|
||||||
|
|
||||||
|
param_list = config["parameters"]["t"]
|
||||||
|
params = {}
|
||||||
|
filter_keys = set()
|
||||||
|
for pl in param_list[:-1]:
|
||||||
|
params[pl["filter"]] = unpack_values(pl["value"])
|
||||||
|
filter_keys.add(pl["filter"])
|
||||||
|
default = unpack_values(param_list[-1]["value"])
|
||||||
|
|
||||||
|
for e in range(num_layers):
|
||||||
|
bl = base_model.layers[e]
|
||||||
|
l = model.layers[e]
|
||||||
|
base_weights = bl.parameters()
|
||||||
|
weights = l.parameters()
|
||||||
|
for k, w1 in base_weights.items():
|
||||||
|
w2 = weights[k]
|
||||||
|
t = params.get(k, default)[e]
|
||||||
|
base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
|
||||||
|
base_model.update(base_weights)
|
||||||
|
|
||||||
|
|
||||||
|
def merge(
|
||||||
|
config: str,
|
||||||
|
mlx_path: str = "mlx_model",
|
||||||
|
upload_repo: Optional[str] = None,
|
||||||
|
):
|
||||||
|
with open(config, "r") as fid:
|
||||||
|
merge_conf = yaml.safe_load(fid)
|
||||||
|
print("[INFO] Loading")
|
||||||
|
|
||||||
|
model_paths = merge_conf.get("models", [])
|
||||||
|
if len(model_paths) < 2:
|
||||||
|
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
|
||||||
|
|
||||||
|
# Load all models
|
||||||
|
base_hf_path = model_paths[0]
|
||||||
|
base_path = get_model_path(base_hf_path)
|
||||||
|
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||||
|
models = []
|
||||||
|
for mp in model_paths[1:]:
|
||||||
|
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||||
|
base_type = base_config["model_type"]
|
||||||
|
model_type = model_config["model_type"]
|
||||||
|
if base_type != model_type:
|
||||||
|
raise ValueError(
|
||||||
|
f"Can only merge models of the same type,"
|
||||||
|
f" but got {base_type} and {model_type}."
|
||||||
|
)
|
||||||
|
models.append(model)
|
||||||
|
|
||||||
|
# Merge models into base model
|
||||||
|
for m in models:
|
||||||
|
merge_models(base_model, m, merge_conf)
|
||||||
|
|
||||||
|
# Save base model
|
||||||
|
mlx_path = Path(mlx_path)
|
||||||
|
weights = dict(tree_flatten(base_model.parameters()))
|
||||||
|
del models, base_model
|
||||||
|
save_weights(mlx_path, weights, donate_weights=True)
|
||||||
|
py_files = glob.glob(str(base_path / "*.py"))
|
||||||
|
for file in py_files:
|
||||||
|
shutil.copy(file, mlx_path)
|
||||||
|
|
||||||
|
tokenizer.save_pretrained(mlx_path)
|
||||||
|
|
||||||
|
save_config(config, config_path=mlx_path / "config.json")
|
||||||
|
|
||||||
|
if upload_repo is not None:
|
||||||
|
upload_to_hub(mlx_path, upload_repo, base_hf_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = configure_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
merge(**vars(args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
llms/mlx_lm/models/__init__.py
Normal file
56
llms/mlx_lm/models/base.py
Normal file
@@ -0,0 +1,56 @@
import inspect
from dataclasses import dataclass

import mlx.core as mx


def create_additive_causal_mask(N: int, offset: int = 0):
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    mask = linds[:, None] < rinds[None]
    return mask * -1e9


class KVCache:

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256

    def update_and_fetch(self, keys, values):
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
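A minimal sketch, not part of the diff, of how the attention layers in the model files below use KVCache; keys and values are shaped (batch, n_kv_heads, seq_len, head_dim):

# Illustration only: the cache grows in blocks of `step` and tracks the true
# sequence length in `offset`.
import mlx.core as mx

cache = KVCache(head_dim=64, n_kv_heads=8)
k = mx.zeros((1, 8, 5, 64))
v = mx.zeros((1, 8, 5, 64))
keys, values = cache.update_and_fetch(k, v)
print(cache.offset, keys.shape)  # 5 (1, 8, 5, 64)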
201
llms/mlx_lm/models/cohere.py
Normal file
@@ -0,0 +1,201 @@
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
hidden_size: int = 8192
|
||||||
|
num_hidden_layers: int = 40
|
||||||
|
intermediate_size: int = 22528
|
||||||
|
num_attention_heads: int = 64
|
||||||
|
num_key_value_heads: int = 64
|
||||||
|
rope_theta: float = 8000000.0
|
||||||
|
vocab_size: int = 256000
|
||||||
|
layer_norm_eps: float = 1e-05
|
||||||
|
logit_scale: float = 0.0625
|
||||||
|
attention_bias: bool = False
|
||||||
|
layer_norm_bias: bool = False
|
||||||
|
use_qk_norm: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm2D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d1, d2, eps):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = mx.zeros((d1, d2))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
|
||||||
|
head_dim = args.hidden_size // args.num_attention_heads
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
|
attention_bias = args.attention_bias
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||||
|
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
|
self.use_qk_norm = args.use_qk_norm
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
|
||||||
|
self.k_norm = LayerNorm2D(
|
||||||
|
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
|
queries = queries.reshape(B, L, self.n_heads, -1)
|
||||||
|
keys = keys.reshape(B, L, self.n_kv_heads, -1)
|
||||||
|
if self.use_qk_norm:
|
||||||
|
queries = self.q_norm(queries)
|
||||||
|
keys = self.k_norm(keys)
|
||||||
|
|
||||||
|
queries = queries.transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = args.hidden_size
|
||||||
|
self.n_heads = args.num_attention_heads
|
||||||
|
|
||||||
|
self.self_attn = Attention(args)
|
||||||
|
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||||
|
self.input_layernorm = nn.LayerNorm(
|
||||||
|
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||||
|
)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
h = self.input_layernorm(x)
|
||||||
|
attn_h = self.self_attn(h, mask, cache)
|
||||||
|
ff_h = self.mlp(h)
|
||||||
|
return attn_h + ff_h + x
|
||||||
|
|
||||||
|
|
||||||
|
class CohereModel(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
|
assert self.vocab_size > 0
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.layers = [
|
||||||
|
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||||
|
]
|
||||||
|
self.norm = nn.LayerNorm(
|
||||||
|
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if h.shape[1] > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||||
|
mask = mask.astype(h.dtype)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.model = CohereModel(args)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.model(inputs, cache)
|
||||||
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
|
out = out * self.model.args.logit_scale
|
||||||
|
return out
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.args.hidden_size // self.args.num_attention_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_kv_heads(self):
|
||||||
|
return self.args.num_key_value_heads
|
||||||
261
llms/mlx_lm/models/dbrx.py
Normal file
@@ -0,0 +1,261 @@
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
vocab_size: int
|
||||||
|
d_model: int
|
||||||
|
ffn_config: dict
|
||||||
|
attn_config: dict
|
||||||
|
n_layers: int
|
||||||
|
n_heads: int
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = args.n_heads
|
||||||
|
self.d_model = args.d_model
|
||||||
|
self.head_dim = args.d_model // args.n_heads
|
||||||
|
self.num_key_value_heads = args.attn_config["kv_n_heads"]
|
||||||
|
self.clip_qkv = args.attn_config["clip_qkv"]
|
||||||
|
self.rope_theta = args.attn_config["rope_theta"]
|
||||||
|
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.Wqkv = nn.Linear(
|
||||||
|
args.d_model,
|
||||||
|
(self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
|
||||||
|
self.rope = nn.RoPE(
|
||||||
|
self.head_dim,
|
||||||
|
traditional=False,
|
||||||
|
base=self.rope_theta,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
qkv = self.Wqkv(x)
|
||||||
|
qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
|
||||||
|
splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
|
||||||
|
queries, keys, values = mx.split(qkv, splits, axis=-1)
|
||||||
|
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||||
|
0, 2, 1, 3
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.out_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class NormAttnNorm(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
|
||||||
|
self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
|
||||||
|
self.attn = Attention(args)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
|
||||||
|
x = h + x
|
||||||
|
return x, self.norm_2(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, d_model: int, ffn_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||||
|
self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||||
|
self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
|
||||||
|
self.act_fn = nn.silu
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
|
||||||
|
current_hidden_states = self.w2(current_hidden_states)
|
||||||
|
return current_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Router(nn.Module):
|
||||||
|
def __init__(self, d_model: int, num_experts: int):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = nn.Linear(d_model, num_experts, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
return self.layer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMoeBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = args.d_model
|
||||||
|
self.ffn_dim = args.ffn_config["ffn_hidden_size"]
|
||||||
|
self.num_experts = args.ffn_config["moe_num_experts"]
|
||||||
|
self.num_experts_per_tok = args.ffn_config["moe_top_k"]
|
||||||
|
|
||||||
|
self.router = Router(self.d_model, self.num_experts)
|
||||||
|
self.experts = [
|
||||||
|
MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
ne = self.num_experts_per_tok
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.reshape(-1, x.shape[-1])
|
||||||
|
|
||||||
|
gates = self.router(x)
|
||||||
|
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
|
||||||
|
|
||||||
|
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
|
||||||
|
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||||
|
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
|
||||||
|
scores = scores.astype(x.dtype)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
inds = np.array(inds)
|
||||||
|
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||||
|
for e, expert in enumerate(self.experts):
|
||||||
|
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||||
|
if idx1.size == 0:
|
||||||
|
continue
|
||||||
|
y[idx1, idx2] = expert(x[idx1])
|
||||||
|
|
||||||
|
y = (y * scores[:, :, None]).sum(axis=1)
|
||||||
|
else:
|
||||||
|
y = []
|
||||||
|
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||||
|
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||||
|
yt = (yt * st).sum(axis=-1)
|
||||||
|
y.append(yt)
|
||||||
|
y = mx.stack(y, axis=0)
|
||||||
|
|
||||||
|
return y.reshape(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderLayer(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.ffn = SparseMoeBlock(args)
|
||||||
|
self.norm_attn_norm = NormAttnNorm(args)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
r, h = self.norm_attn_norm(x, mask, cache)
|
||||||
|
out = self.ffn(h) + r
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DBRX(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
self.wte = nn.Embedding(args.vocab_size, args.d_model)
|
||||||
|
self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
|
||||||
|
self.norm_f = nn.LayerNorm(args.d_model, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
h = self.wte(inputs)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
T = h.shape[1]
|
||||||
|
if T > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||||
|
mask = mask.astype(h.dtype)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.blocks)
|
||||||
|
|
||||||
|
for layer, c in zip(self.blocks, cache):
|
||||||
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
return self.norm_f(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.transformer = DBRX(args)
|
||||||
|
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.transformer(inputs, cache)
|
||||||
|
return self.lm_head(out)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.transformer.blocks
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
# Split experts into sub matrices
|
||||||
|
num_experts = self.args.ffn_config["moe_num_experts"]
|
||||||
|
dim = self.args.ffn_config["ffn_hidden_size"]
|
||||||
|
|
||||||
|
pattern = "experts.mlp"
|
||||||
|
new_weights = {k: v for k, v in weights.items() if pattern not in k}
|
||||||
|
for k, v in weights.items():
|
||||||
|
if pattern in k:
|
||||||
|
experts = [
|
||||||
|
(k.replace(".mlp", f".{e}") + ".weight", sv)
|
||||||
|
for e, sv in enumerate(mx.split(v, num_experts, axis=0))
|
||||||
|
]
|
||||||
|
if k.endswith("w2"):
|
||||||
|
experts = [(s, sv.T) for s, sv in experts]
|
||||||
|
new_weights.update(experts)
|
||||||
|
return new_weights
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.args.d_model // self.args.n_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_kv_heads(self):
|
||||||
|
return self.args.attn_config["kv_n_heads"]
|
||||||
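A small illustration, not part of the diff, of the top-k routing used by SparseMoeBlock above: argpartition picks the highest-scoring experts per token (in no particular order) and the selected gate values are renormalized to sum to one:

# Illustration only.
import mlx.core as mx

gates = mx.softmax(mx.array([[0.1, 2.0, 0.5, 1.5]]), axis=-1)
ne = 2  # moe_top_k
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
print(inds.tolist())    # the two highest-scoring experts, e.g. [[1, 3]]
print(scores.tolist())  # their renormalized mixing weights, summing to 1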
184
llms/mlx_lm/models/gemma.py
Normal file
@@ -0,0 +1,184 @@
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
hidden_size: int
|
||||||
|
num_hidden_layers: int
|
||||||
|
intermediate_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
head_dim: int
|
||||||
|
rms_norm_eps: float
|
||||||
|
vocab_size: int
|
||||||
|
num_key_value_heads: int
|
||||||
|
rope_theta: float = 10000
|
||||||
|
rope_traditional: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = mx.ones((dims,))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim = args.hidden_size
|
||||||
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
|
self.head_dim = head_dim = args.head_dim
|
||||||
|
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||||
|
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||||
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(
|
||||||
|
head_dim,
|
||||||
|
traditional=args.rope_traditional,
|
||||||
|
base=args.rope_theta,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, D = x.shape
|
||||||
|
|
||||||
|
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||||
|
|
||||||
|
# Prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
queries = self.rope(queries, offset=cache.offset)
|
||||||
|
keys = self.rope(keys, offset=cache.offset)
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
else:
|
||||||
|
queries = self.rope(queries)
|
||||||
|
keys = self.rope(keys)
|
||||||
|
|
||||||
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||||
|
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x) -> mx.array:
|
||||||
|
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.num_attention_heads = args.num_attention_heads
|
||||||
|
self.hidden_size = args.hidden_size
|
||||||
|
self.self_attn = Attention(args)
|
||||||
|
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||||
|
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
|
h = x + r
|
||||||
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
|
out = h + r
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaModel(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
self.num_hidden_layers = args.num_hidden_layers
|
||||||
|
assert self.vocab_size > 0
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.layers = [
|
||||||
|
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||||
|
]
|
||||||
|
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
h = self.embed_tokens(inputs)
|
||||||
|
h = h * (self.args.hidden_size**0.5)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if h.shape[1] > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||||
|
mask = mask.astype(h.dtype)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.model_type = args.model_type
|
||||||
|
self.model = GemmaModel(args)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
cache=None,
|
||||||
|
):
|
||||||
|
out = self.model(inputs, cache)
|
||||||
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.args.head_dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_kv_heads(self):
|
||||||
|
return self.args.num_key_value_heads
|
||||||
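Gemma ties the output projection to the input embedding: the final hidden states are projected to logits with `embed_tokens.as_linear` rather than a separate `lm_head`. A minimal sketch of what that weight tying amounts to, with made-up toy sizes:

```python
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(16, 4)        # toy vocab_size=16, hidden_size=4
h = mx.random.normal((1, 3, 4))  # a batch of hidden states

logits = emb.as_linear(h)        # equivalent to h @ emb.weight.T
print(logits.shape)              # (1, 3, 16)
```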
207  llms/mlx_lm/models/gpt2.py  Normal file
@@ -0,0 +1,207 @@

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_additive_causal_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -> mx.array:
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out


class GPT2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size > 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] > 1:

            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            mask = create_additive_causal_mask(
                hidden_states.shape[1], cache[0].offset if cache is not None else 0
            )
            mask = mask.astype(hidden_states.dtype)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.wte.as_linear(out)
        return out

    def sanitize(self, weights):
        new_weights = {}
        for i in range(self.args.n_layer):
            if f"h.{i}.attn.bias" in weights:
                del weights[f"h.{i}.attn.bias"]
            if f"h.{i}.attn.c_attn.weight" in weights:
                weights[f"h.{i}.attn.c_attn.weight"] = weights[
                    f"h.{i}.attn.c_attn.weight"
                ].transpose(1, 0)
            if f"h.{i}.attn.c_proj.weight" in weights:
                weights[f"h.{i}.attn.c_proj.weight"] = weights[
                    f"h.{i}.attn.c_proj.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_fc.weight" in weights:
                weights[f"h.{i}.mlp.c_fc.weight"] = weights[
                    f"h.{i}.mlp.c_fc.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_proj.weight" in weights:
                weights[f"h.{i}.mlp.c_proj.weight"] = weights[
                    f"h.{i}.mlp.c_proj.weight"
                ].transpose(1, 0)
        for weight in weights:
            if not weight.startswith("model."):
                new_weights[f"model.{weight}"] = weights[weight]
            else:
                new_weights[weight] = weights[weight]
        return new_weights

    @property
    def layers(self):
        return self.model.h

    @property
    def head_dim(self):
        return self.args.n_embd // self.args.n_head

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
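The transposes in `sanitize` are needed because upstream GPT-2 checkpoints store `c_attn`, `c_proj`, and `c_fc` in a Conv1D-style layout of shape (in_features, out_features), while `nn.Linear` holds (out_features, in_features). A tiny sketch of just that shape flip, with a toy stand-in array rather than a real checkpoint tensor:

```python
import mlx.core as mx

# Toy stand-in for a c_attn weight stored as (in_features, out_features).
w = mx.random.normal((768, 3 * 768))
w_linear = w.transpose(1, 0)      # layout expected by nn.Linear
print(w.shape, w_linear.shape)    # (768, 2304) (2304, 768)
```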
195  llms/mlx_lm/models/gpt_bigcode.py  Normal file
@@ -0,0 +1,195 @@

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_additive_causal_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_embd: int
    n_layer: int
    n_inner: int
    n_head: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None
    multi_query: bool = True
    attention_bias: bool = True
    mlp_bias: bool = True
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = 1 if self.multi_query else self.n_head


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.dim = dim = args.n_embd
        self.n_heads = n_heads = args.n_head
        self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head

        self.head_dim = head_dim = dim // n_heads

        self.kv_dim = n_kv_heads * head_dim

        self.scale = head_dim**-0.5

        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
        self.c_proj = nn.Linear(dim, dim, bias=attention_bias)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(
            qkv, [self.dim, self.dim + self.kv_dim], axis=-1
        )

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.n_embd
        hidden_dim = args.n_inner
        if hasattr(args, "mlp_bias"):
            mlp_bias = args.mlp_bias
        else:
            mlp_bias = False

        self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.c_proj(nn.gelu(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out


class GPTBigCodeModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0
        self.wte = nn.Embedding(args.vocab_size, args.n_embd)
        self.wpe = nn.Embedding(args.n_positions, args.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
        self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        B, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] > 1:

            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            mask = create_additive_causal_mask(
                hidden_states.shape[1], cache[0].offset if cache is not None else 0
            )
            mask = mask.astype(hidden_states.dtype)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = GPTBigCodeModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.transformer.wte.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.transformer.h

    @property
    def head_dim(self):
        return self.args.n_embd // self.args.n_head

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
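With `multi_query` enabled, the fused `c_attn` projection emits full-width queries plus a single shared key/value head, and `mx.split` with the offsets `[dim, dim + kv_dim]` separates the three pieces. A toy sketch of just that split, with made-up sizes:

```python
import mlx.core as mx

dim, head_dim = 12, 4             # toy sizes: 3 query heads, 1 shared KV head
kv_dim = 1 * head_dim
qkv = mx.random.normal((1, 5, dim + 2 * kv_dim))
q, k, v = mx.split(qkv, [dim, dim + kv_dim], axis=-1)
print(q.shape, k.shape, v.shape)  # (1, 5, 12) (1, 5, 4) (1, 5, 4)
```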
198  llms/mlx_lm/models/internlm2.py  Normal file
@@ -0,0 +1,198 @@

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    bias: bool = True
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.n_kv_groups = n_heads // args.num_key_value_heads

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.wqkv = nn.Linear(
            dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
        )
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv_states = self.wqkv(x)
        qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)

        queries = qkv_states[..., : self.n_kv_groups, :]
        queries = queries.reshape(B, L, -1, self.head_dim)
        keys = qkv_states[..., -2, :]
        values = qkv_states[..., -1, :]

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
        self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out


class InternLM2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size > 0
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.tok_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = InternLM2Model(args)
        if not args.tie_word_embeddings:
            self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.tok_embeddings.as_linear(out)
        else:
            out = self.output(out)
        return out

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
220  llms/mlx_lm/models/llama.py  Normal file
@@ -0,0 +1,220 @@

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, KVCache, create_additive_causal_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        if hasattr(args, "mlp_bias"):
            mlp_bias = args.mlp_bias
        else:
            mlp_bias = False

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = create_additive_causal_mask(
                h.shape[1], cache[0].offset if cache is not None else 0
            )
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
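All of these model files share the same contract: build `Model(ModelArgs(...))`, call it on a batch of token ids, and read logits back, with an optional per-layer cache for incremental decoding. A minimal sketch against the llama file above; the hyperparameter values are tiny made-up numbers, not a real configuration:

```python
import mlx.core as mx
from mlx_lm.models import llama

args = llama.ModelArgs(
    model_type="llama",
    hidden_size=64,          # toy sizes for illustration only
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    rms_norm_eps=1e-5,
    vocab_size=100,
)
model = llama.Model(args)
logits = model(mx.array([[1, 2, 3]]))
print(logits.shape)  # (1, 3, 100): one row of logits per input token
```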
216  llms/mlx_lm/models/minicpm.py  Normal file
@@ -0,0 +1,216 @@

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    dim_model_base: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    scale_depth: float
    scale_emb: float
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[str, float]]] = None
    tie_word_embeddings: bool = False


class MLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.num_heads = n_heads = args.num_attention_heads
        self.rope_theta = args.rope_theta

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.num_key_value_heads = args.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )

        self.rope = nn.RoPE(
            dims=self.head_dim,
            traditional=args.rope_traditional,
            base=self.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ):
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        attn_output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.o_proj(attn_output)


class DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.hidden_size = args.hidden_size
        self.num_hidden_layers = args.num_hidden_layers

        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        self.scale_depth = args.scale_depth
        self.num_hidden_layers = args.num_hidden_layers

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        return out


class MiniCPMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs) * self.args.scale_emb

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MiniCPMModel(args)

        if not self.args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)

        if not self.args.tie_word_embeddings:
            out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
        else:
            out = out @ self.model.embed_tokens.weight.T

        return out

    def sanitize(self, weights):
        if "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
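MiniCPM applies three scalings on top of the usual decoder stack: embeddings are multiplied by `scale_emb`, each residual branch by `scale_depth / sqrt(num_hidden_layers)`, and the final hidden states are divided by `hidden_size / dim_model_base` before the output projection. A small numeric sketch of those factors, with illustrative values rather than a real config:

```python
import numpy as np

# Illustrative values, not taken from an actual MiniCPM config.
num_hidden_layers, scale_depth = 40, 1.4
hidden_size, dim_model_base = 2304, 256

residual_scale = scale_depth / np.sqrt(num_hidden_layers)  # damps each residual branch
logit_divisor = hidden_size / dim_model_base               # pre-lm_head rescaling
print(residual_scale, logit_divisor)                       # ~0.221, 9.0
```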
227  llms/mlx_lm/models/mixtral.py  Normal file
@@ -0,0 +1,227 @@

import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_experts_per_tok: int = 2
    num_key_value_heads: int = 8
    num_local_experts: int = 8
    rms_norm_eps: float = 1e-5
    rope_theta: float = 1e6
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class MixtralAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.num_heads = args.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.rope_theta = args.rope_theta

        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rope = nn.RoPE(
            self.head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MixtralSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_dim = args.hidden_size
        self.ffn_dim = args.intermediate_size
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x)

        k = self.num_experts_per_tok
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(gates, inds, axis=-1)
        scores = mx.softmax(scores, axis=-1, precise=True)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)

        return y


class MixtralDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = MixtralAttention(args)

        self.block_sparse_moe = MixtralSparseMoeBlock(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.block_sparse_moe(self.post_attention_layernorm(h))
        out = h + r
        return out


class MixtralModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        T = h.shape[1]
        if T > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = MixtralModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
            return weights
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(
                                f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
                            )
                            for e in range(self.args.num_local_experts)
                        ]
                        weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
                            mx.stack(to_join)
                        )
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
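The MoE block above routes each token to its top-k experts by argpartitioning the gate logits and softmaxing only the selected scores. A standalone sketch of just that routing math on toy data (no experts involved, only the selection step):

```python
import mlx.core as mx

gates = mx.array([[0.1, 2.0, -1.0, 0.5]])  # one token, four experts (toy logits)
k = 2
# Indices of the k largest gate values, then weights normalized over just those k.
inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
print(inds, scores)  # experts 1 and 3, with weights summing to 1
```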
185
llms/mlx_lm/models/olmo.py
Normal file
185
llms/mlx_lm/models/olmo.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from sys import exit
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
try:
|
||||||
|
import hf_olmo
|
||||||
|
except ImportError:
|
||||||
|
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
    model_type: str
    d_model: int
    n_layers: int
    mlp_hidden_size: int
    n_heads: int
    vocab_size: int
    embedding_size: int
    rope_theta: float = 10000
    rope_traditional: bool = False
    mlp_ratio: int = 4
    weight_tying: bool = False

    def __post_init__(self):
        self.mlp_hidden_size = (
            self.mlp_hidden_size
            if self.mlp_hidden_size is not None
            else self.mlp_ratio * self.d_model
        )


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        dim = args.d_model

        self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
        self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)

        self.att_norm = nn.LayerNorm(dim, affine=False)
        self.ff_norm = nn.LayerNorm(dim, affine=False)

        head_dim = dim // self.n_heads
        self.scale = head_dim**-0.5

        self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

        self.args = args

    def attend(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.attn_out(output)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attend(self.att_norm(x), mask, cache)
        h = x + r

        x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)

        out = h + self.ff_out(nn.silu(x2) * x1)
        return out


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_layers = args.n_layers
        self.weight_tying = args.weight_tying

        self.wte = nn.Embedding(args.embedding_size, args.d_model)
        self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        if not self.weight_tying:
            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
        self.norm = nn.LayerNorm(args.d_model, affine=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.wte(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.blocks)

        for block, c in zip(self.blocks, cache):
            h = block(h, mask, c)

        h = self.norm(h)

        if self.weight_tying:
            # Reuse the input embedding as the output projection
            return self.wte.as_linear(h)

        return self.ff_out(h)


class OlmoModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.transformer = Transformer(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        return self.transformer(inputs, cache)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = OlmoModel(args)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        return self.model(inputs, cache)

    @property
    def layers(self):
        return self.model.transformer.blocks

    @property
    def head_dim(self):
        return self.args.d_model // self.args.n_heads

    @property
    def n_kv_heads(self):
        return self.args.n_heads
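
The OLMo block above projects to mlp_hidden_size and then splits that projection in half, which is why ff_out consumes mlp_hidden_size // 2. A minimal sketch of that shape bookkeeping, with toy sizes chosen purely for illustration:

import mlx.core as mx
import mlx.nn as nn

d_model, mlp_hidden_size = 8, 32            # toy sizes, not from any real config
ff_proj = nn.Linear(d_model, mlp_hidden_size, bias=False)
ff_out = nn.Linear(mlp_hidden_size // 2, d_model, bias=False)

h = mx.zeros((1, 4, d_model))               # (batch, sequence, d_model)
x1, x2 = mx.split(ff_proj(h), 2, axis=-1)   # two halves of size mlp_hidden_size // 2
print(ff_out(nn.silu(x2) * x1).shape)       # (1, 4, 8): back to d_model
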
229
llms/mlx_lm/models/openelm.py
Normal file
@@ -0,0 +1,229 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    head_dim: int
    num_transformer_layers: int
    model_dim: int
    vocab_size: int
    ffn_dim_divisor: int
    num_query_heads: List
    num_kv_heads: List
    ffn_multipliers: List
    ffn_with_glu: bool = True
    normalize_qk_projections: bool = True
    share_input_output_layers: bool = True
    rms_norm_eps: float = 1e-6
    rope_freq_constant: float = 10000


def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by the divisor.
    It can be seen at:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
    Args:
        v: input value
        divisor: default to 8
        min_value: minimum divisor value
    Returns:
        new_v: new divisible value
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.head_dim = head_dim = args.head_dim
        self.layer_id = layer_id
        self.model_dim = model_dim = args.model_dim

        self.n_heads = n_heads = args.num_query_heads[layer_id]
        self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id]
        self.scale = head_dim**-0.5

        op_size = (n_heads + (n_kv_heads * 2)) * head_dim
        self.qkv_proj = nn.Linear(model_dim, op_size, bias=False)
        self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)

        self.normalize_qk_projections = args.normalize_qk_projections

        if self.normalize_qk_projections:
            self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
            self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)

        self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.qkv_proj(x)

        qkv = qkv.reshape(
            B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
        ).transpose(0, 2, 1, 3)

        queries, keys, values = mx.split(
            qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
        )

        # Prepare the queries, keys and values for the attention computation
        if self.normalize_qk_projections:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.args = args
        dim = args.model_dim
        ffn_multiplier = args.ffn_multipliers[layer_id]

        intermediate_dim = int(
            make_divisible(
                ffn_multiplier * args.model_dim,
                divisor=args.ffn_dim_divisor,
            )
        )

        self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False)
        self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        x = self.proj_1(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.proj_2(nn.silu(gate) * x)


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        dim = args.model_dim
        self.attn = Attention(args, layer_id=layer_id)
        self.ffn = MLP(args, layer_id=layer_id)
        self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
        self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attn(self.attn_norm(x), mask, cache)
        h = x + r
        r = self.ffn(self.ffn_norm(h))
        out = h + r
        return out


class OpenELMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_transformer_layers = args.num_transformer_layers
        assert self.vocab_size > 0
        self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
        self.layers = [
            TransformerBlock(args, layer_id=layer_id)
            for layer_id in range(self.num_transformer_layers)
        ]
        self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.token_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = OpenELMModel(args)
        if not args.share_input_output_layers:
            self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.transformer(inputs, cache)
        if self.args.share_input_output_layers:
            out = self.transformer.token_embeddings.as_linear(out)
        else:
            out = self.lm_head(out)

        return out

    @property
    def layers(self):
        return self.transformer.layers

    @property
    def head_dim(self):
        return self.args.head_dim

    @property
    def n_kv_heads(self):
        return self.args.num_kv_heads
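
Since every OpenELM layer derives its FFN width from make_divisible, a quick standalone check of that rounding rule can be useful. The function body below is copied from the file above; the input values are arbitrary:

def make_divisible(v, divisor=8, min_value=None):
    # Same logic as make_divisible in openelm.py above.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

print(make_divisible(0.5 * 1280, divisor=256))  # 768: 640 rounds to the nearest multiple of 256
print(make_divisible(100, divisor=256))         # 256: never drops below the divisor
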
182
llms/mlx_lm/models/openlm.py
Normal file
@@ -0,0 +1,182 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_additive_causal_mask


@dataclass
class ParamsArgs(BaseModelArgs):
    dim: int
    ffn_type: str
    n_heads: int
    n_layers: int
    norm_eps: float
    positional_embedding_type: str
    post_embed_norm: bool
    qk_norm: bool
    vocab_size: int
    weight_tying: bool


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    params_args_dict: ParamsArgs


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.dim = args.dim
        self.n_heads = args.n_heads
        self.head_dim = self.dim // self.n_heads
        self.qk_norm = args.qk_norm
        self.scale = self.head_dim**-0.5

        self.in_proj = nn.Linear(self.dim, 3 * self.dim, bias=False)
        self.out_proj = nn.Linear(self.dim, self.dim, bias=False)
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
            self.k_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
        self.rope = nn.RoPE(
            self.head_dim,
            traditional=False,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.in_proj(x).split(3, axis=-1)

        if self.qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # https://github.com/mlfoundations/open_lm/blob/c65b43042ff31c0fe26f930decf1ccab1b03ab4b/open_lm/model.py#L254C2-L254C3
        hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
        self.w12 = nn.Linear(args.dim, 2 * hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, args.dim, bias=False)

    def __call__(self, x) -> mx.array:
        gate, x = self.w12(x).split(2, axis=-1)
        return self.w3(nn.silu(gate) * x)


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = MLP(args)
        self.ffn_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
        self.attention_norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out


class OpenLM(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = nn.LayerNorm(args.dim, eps=args.norm_eps, bias=False)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        _, L = inputs.shape

        h = self.tok_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = create_additive_causal_mask(
                h.shape[1], cache[0].offset if cache is not None else 0
            )
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.output(self.norm(h))


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        args.params_args_dict = ParamsArgs.from_dict(args.params_args_dict)
        self.args = args.params_args_dict
        self.model_type = args.model_type
        self.model = OpenLM(self.args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        return out

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {k: v for k, v in weights.items() if "inv_freq" not in k}

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.dim // self.args.n_heads

    @property
    def n_kv_heads(self):
        return self.args.n_heads
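
The hard-coded expression in MLP above reproduces open_lm's SwiGLU sizing: roughly 8/3 of the model width, rounded up to a multiple of 256. A small worked example with illustrative dims:

def swiglu_hidden_dim(dim: int) -> int:
    # Same formula as in openlm.py's MLP above.
    return 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256)

print(swiglu_hidden_dim(2048))  # 5632  (8/3 * 2048 ≈ 5461 -> next multiple of 256)
print(swiglu_hidden_dim(4096))  # 11008
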
180
llms/mlx_lm/models/phi.py
Normal file
@@ -0,0 +1,180 @@
import math
from dataclasses import dataclass
from typing import Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str = "phi"
    max_position_embeddings: int = 2048
    vocab_size: int = 51200
    hidden_size: int = 2560
    num_attention_heads: int = 32
    num_hidden_layers: int = 32
    num_key_value_heads: int = 32
    partial_rotary_factor: float = 0.4
    intermediate_size: int = 10240
    layer_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class PhiAttention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.repeats = self.num_heads // self.num_key_value_heads
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=True
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.dense = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=True
        )

        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            traditional=False,
            base=self.rope_theta,
        )

    def __call__(self, x, mask=None, cache=None):
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Extract some shapes
        B, L, D = queries.shape
        n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(
            B,
            L,
            n_heads,
            -1,
        ).moveaxis(1, 2)
        keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
        values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scale = math.sqrt(1 / queries.shape[-1])
        output = mx.fast.scaled_dot_product_attention(
            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
        ).astype(values.dtype)

        output = output.moveaxis(2, 1).reshape(B, L, -1)

        return self.dense(output)


class PhiMLP(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = nn.GELU(approx="precise")

    def __call__(self, x) -> mx.array:
        return self.fc2(self.act(self.fc1(x)))


class PhiDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.self_attn = PhiAttention(config=config)
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.mlp = PhiMLP(config)

    def __call__(self, x, mask, cache):
        h = self.input_layernorm(x)
        attn_h = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x


class PhiModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def __call__(self, x, cache):
        x = self.embed_tokens(x)
        if cache is None:
            cache = [None] * len(self.layers)

        mask = None
        if x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)

        for layer, c in zip(self.layers, cache):
            x = layer(x, mask, c)
        return self.final_layernorm(x)


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model_type = config.model_type
        self.model = PhiModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
        self.args = config

    def __call__(
        self,
        x: mx.array,
        cache: mx.array = None,
    ) -> Tuple[mx.array, mx.array]:
        y = self.model(x, cache)
        return self.lm_head(y)

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
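
Phi applies RoPE to only part of each head: nn.RoPE above is built with int(partial_rotary_factor * head_dim) dims, and MLX rotates only that many leading features of each head, passing the rest through unchanged. With the default sizes above, the arithmetic is:

hidden_size, num_attention_heads, partial_rotary_factor = 2560, 32, 0.4
head_dim = hidden_size // num_attention_heads        # 80 features per head
rotary_dims = int(partial_rotary_factor * head_dim)  # 32 features get rotated
print(head_dim, rotary_dims, head_dim - rotary_dims) # 80 32 48 (48 pass through)
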
213
llms/mlx_lm/models/phi3.py
Normal file
@@ -0,0 +1,213 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .su_rope import SuScaledRotaryEmbedding


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    max_position_embeddings: int = 131072
    original_max_position_embeddings: int = 4096

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"long_factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] not in ["su", "linear"]:
                print(
                    "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
                )
                self.rope_scaling = None


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.num_hidden_layers = args.num_hidden_layers

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] == "su":
            self.rope = SuScaledRotaryEmbedding(
                head_dim,
                traditional=False,
                base=args.rope_theta,
                scale=rope_scale,
                max_position_embeddings=args.max_position_embeddings,
                original_max_position_embeddings=args.original_max_position_embeddings,
                short_factor=args.rope_scaling["short_factor"],
                long_factor=args.rope_scaling["long_factor"],
            )
        else:
            if args.rope_scaling and args.rope_scaling["type"] == "linear":
                rope_scale = 1 / args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.qkv_proj(x)
        query_pos = self.n_heads * self.head_dim
        queries, keys, values = mx.split(
            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
        )

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class Phi3Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3Model(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
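
The attention above splits one fused qkv projection at two offsets rather than using separate q/k/v layers. A toy-sized sketch of that split (the sizes below are assumptions for illustration only):

import mlx.core as mx

n_heads, n_kv_heads, head_dim = 4, 2, 8      # toy sizes
qkv = mx.zeros((1, 3, (n_heads + 2 * n_kv_heads) * head_dim))
query_pos = n_heads * head_dim
q, k, v = mx.split(qkv, [query_pos, query_pos + n_kv_heads * head_dim], axis=-1)
print(q.shape, k.shape, v.shape)             # (1, 3, 32) (1, 3, 16) (1, 3, 16)
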
318
llms/mlx_lm/models/phi3small.py
Normal file
@@ -0,0 +1,318 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    dense_attention_every_n_layers: int
    ff_intermediate_size: int
    gegelu_limit: float
    num_hidden_layers: int
    num_attention_heads: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None
    mup_attn_multiplier: float = 1.0
    mup_use_scaling: bool = True
    mup_embedding_multiplier: float = 10.0
    mup_width_multiplier: float = 8.0
    rope_embedding_base: float = 1000000
    rope_position_scale: float = 1.0
    blocksparse_block_size: int = 64
    blocksparse_num_local_blocks: int = 16
    blocksparse_vert_stride: int = 8


@partial(mx.compile, shapeless=True)
def gegelu_impl(a_gelu, a_linear, limit):
    a_gelu = mx.where(
        mx.isinf(a_gelu),
        a_gelu,
        mx.clip(a_gelu, a_min=None, a_max=limit),
    )
    a_linear = mx.where(
        mx.isinf(a_linear),
        a_linear,
        mx.clip(a_linear, a_min=-limit, a_max=limit),
    )
    out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
    return out_gelu * (a_linear + 1.0)


def gegelu(x, limit):
    a_gelu, a_linear = x[..., ::2], x[..., 1::2]
    return gegelu_impl(a_gelu, a_linear, limit)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.n_q_per_kv = n_heads // n_kv_heads

        self.head_dim = head_dim = args.hidden_size // n_heads

        self.query_key_value = nn.Linear(
            dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
        )
        self.dense = nn.Linear(dim, dim)

        if args.mup_use_scaling:
            norm_factor = head_dim / args.mup_attn_multiplier
        else:
            norm_factor = math.sqrt(head_dim)
        self.scale = 1.0 / norm_factor

        self.rope = nn.RoPE(
            head_dim,
            traditional=False,
            base=args.rope_embedding_base,
            scale=args.rope_position_scale,
        )

        if layer_idx % args.dense_attention_every_n_layers == 0:
            self.block_sparse = True
            self.blocksparse_block_size = args.blocksparse_block_size
            if self.blocksparse_block_size not in (32, 64):
                raise ValueError(
                    f"Unsupported block size {self.blocksparse_block_size}"
                )
            self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
            self.blocksparse_vert_stride = args.blocksparse_vert_stride
        else:
            self.block_sparse = False

    def _block_sparse_mask(self, q_len, kv_len):
        vert_stride = self.blocksparse_vert_stride
        local_blocks = self.blocksparse_num_local_blocks
        block_size = self.blocksparse_block_size
        n_heads = self.n_heads

        kv_blocks = (kv_len + block_size - 1) // block_size
        q_blocks = (q_len + block_size - 1) // block_size
        q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
        k_pos = mx.arange(kv_blocks)[None, None]

        mask_vert_strided = (
            mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
        ) % vert_stride
        mask_vert_strided = (mask_vert_strided == 0)[:, None, :]

        block_mask = (q_pos >= k_pos) & (
            (q_pos - k_pos < local_blocks) | mask_vert_strided
        )
        block_mask = block_mask.reshape(
            self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
        )
        dense_mask = mx.repeat(
            mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
        )
        return block_mask, dense_mask[..., -q_len:, :kv_len]

    def _block_sparse_attention(self, queries, keys, values, scale, mask):
        queries = scale * queries
        B = queries.shape[0]
        L = queries.shape[2]
        queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
        keys = mx.expand_dims(keys, 2)
        values = mx.expand_dims(values, 2)

        # TODO get rid of dense mask if we have a fill value
        block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
        scores = queries @ mx.swapaxes(keys, -1, -2)
        # TODO, uncomment when faster
        # scores = mx.block_masked_mm(
        #     queries,
        #     mx.swapaxes(keys, -1, -2),
        #     mask_out=block_mask,
        #     block_size=self.blocksparse_block_size,
        # )

        if mask is not None:
            scores = scores + mask
        scores = scores + mx.where(
            dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
        )
        scores = mx.softmax(scores, axis=-1, precise=True)

        output = scores @ values
        # TODO, uncomment when faster
        # output = mx.block_masked_mm(
        #     scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
        # )
        return mx.reshape(output, (B, self.n_heads, L, -1))

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.query_key_value(x)
        qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
        queries = qkv[..., :-2, :].flatten(-3, -2)
        keys = qkv[..., -2, :]
        values = qkv[..., -1, :]

        # Prepare the queries, keys and values for the attention computation
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        if self.block_sparse:
            output = self._block_sparse_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        else:
            output = mx.fast.scaled_dot_product_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.dense(output)


class MLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        dim = args.hidden_size
        hidden_dim = args.ff_intermediate_size
        self.gegelu_limit = args.gegelu_limit
        self.up_proj = nn.Linear(dim, 2 * hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, dim)

    def __call__(self, x) -> mx.array:
        x = self.up_proj(x)
        return self.down_proj(gegelu(x, self.gegelu_limit))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_epsilon
        )
        self.post_attention_layernorm = nn.LayerNorm(
            args.hidden_size,
            eps=args.layer_norm_epsilon,
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class Phi3Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.mup_embedding_multiplier = args.mup_embedding_multiplier
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=l)
            for l in range(args.num_hidden_layers)
        ]
        self.final_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_epsilon
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        if self.mup_embedding_multiplier:
            h = self.mup_embedding_multiplier * h

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.final_layernorm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3Model(args)
        self.args = args
        self.mup_width_multiplier = args.mup_width_multiplier
        self._dummy_tokenizer_ids = mx.array(
            [100256, 100258, 100259, 100260, 100264, 100265]
            + list(range(100267, 100352))
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.embed_tokens.as_linear(out)
        if self.mup_width_multiplier:
            out = out / self.mup_width_multiplier
        out[self._dummy_tokenizer_ids] = -float("inf")
        return out

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
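
The gegelu activation above pairs interleaved channels: even channels go through a capped GELU-like gate and odd channels are clamped and shifted by one. A minimal sketch of that pairing on a toy tensor (it omits the isinf passthrough used in gegelu_impl):

import mlx.core as mx

x = mx.arange(8, dtype=mx.float32).reshape(1, 8)  # interleaved [gelu, linear, gelu, ...]
limit = 2.0
a_gelu, a_linear = x[..., ::2], x[..., 1::2]
a_gelu = mx.clip(a_gelu, a_min=None, a_max=limit)        # cap the gating branch
a_linear = mx.clip(a_linear, a_min=-limit, a_max=limit)  # clamp the linear branch
out = (a_gelu * mx.sigmoid(1.702 * a_gelu)) * (a_linear + 1.0)
print(out.shape)  # (1, 4): half the channels remain after pairing the two branches
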
203
llms/mlx_lm/models/phixtral.py
Normal file
@@ -0,0 +1,203 @@
import inspect
import math
from dataclasses import dataclass
from typing import Tuple

import mlx.core as mx
import mlx.nn as nn

from .switch_layers import SwitchMLP


@dataclass
class ModelArgs:
    model_type: str
    num_vocab: int = 51200
    model_dim: int = 2560
    num_heads: int = 32
    num_layers: int = 32
    rotary_dim: int = 32
    num_experts_per_tok: int = 2
    num_local_experts: int = 4

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class RoPEAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(rotary_dim, traditional=False)
        self.Wqkv = nn.Linear(dims, 3 * dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x, mask=None, cache=None):
        qkv = self.Wqkv(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        queries = queries.astype(mx.float32)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])

        output = mx.fast.scaled_dot_product_attention(
            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
        ).astype(values.dtype)
        output = output.moveaxis(2, 1).reshape(B, L, -1)

        return self.out_proj(output)


class MOE(nn.Module):
    def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok
        self.switch_mlp = SwitchMLP(
            self.dim, self.hidden_dim, self.num_experts, bias=True
        )
        self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x)

        k = self.num_experts_per_tok
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
        scores = mx.take_along_axis(gates, inds, axis=-1)
        scores = mx.softmax(scores, axis=-1, precise=True)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)

        return y


class ParallelBlock(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        dims = config.model_dim
        mlp_dims = dims * 4
        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
        self.ln = nn.LayerNorm(dims)
        self.moe = MOE(config, dims, mlp_dims)

    def __call__(self, x, mask, cache):
        h = self.ln(x)
        attn_h = self.mixer(h, mask, cache)
        ff_h = self.moe(h)
        return attn_h + ff_h + x


class TransformerDecoder(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embd = Embd(config)
        self.h = [ParallelBlock(config) for i in range(config.num_layers)]

    def __call__(self, x, mask, cache):
        x = self.embd(x)
        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            x = layer(x, mask, c)
        return x


class Embd(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(config.num_vocab, config.model_dim)

    def __call__(self, x):
        return self.wte(x)


class OutputHead(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.model_dim)
        self.linear = nn.Linear(config.model_dim, config.num_vocab)

    def __call__(self, inputs):
        return self.linear(self.ln(inputs))


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model_type = config.model_type
        self.transformer = TransformerDecoder(config)
        self.lm_head = OutputHead(config)
        self.args = config

    def __call__(
        self,
        x: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> Tuple[mx.array, mx.array]:
        mask = None
        if x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)

        y = self.transformer(x, mask, cache)
        return self.lm_head(y)

    def sanitize(self, weights):
        if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
            return weights
        for l in range(self.args.num_layers):
            prefix = f"transformer.h.{l}"
            for n in ["fc1", "fc2"]:
                for k in ["weight", "scales", "biases", "bias"]:
                    if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
                            for e in range(self.args.num_local_experts)
                        ]
                        weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
        return weights

    @property
    def layers(self):
        return self.transformer.h

    @property
    def head_dim(self):
        return self.args.model_dim // self.args.num_heads

    @property
    def n_kv_heads(self):
        return self.args.num_heads
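
The MOE module above routes each token to the top-k experts by gate score. The same argpartition/take_along_axis pattern in isolation, applied to a made-up gate vector:

import mlx.core as mx

gates = mx.array([[0.1, 2.0, -0.5, 1.2]])  # scores for 4 experts, one token
k = 2
inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
print(inds.tolist(), scores.tolist())      # experts 1 and 3 (order may vary), weights sum to 1
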
216
llms/mlx_lm/models/plamo.py
Normal file
@@ -0,0 +1,216 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    n_shared_head: int = 8
    rope_theta: float = 10000
    rope_traditional: bool = False


class Attention(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        head_dim = self.hidden_size // config.num_attention_heads

        self.q_num_heads = config.num_attention_heads
        self.qk_dim = self.v_dim = head_dim
        self.k_num_heads = self.v_num_heads = int(
            np.ceil(self.q_num_heads / config.n_shared_head)
        )

        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.q_num_heads * self.qk_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.k_num_heads * self.qk_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.v_num_heads * self.v_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.q_num_heads * self.v_dim, self.hidden_size, bias=False
        )
        self.rotary_emb = nn.RoPE(
            head_dim,
            traditional=config.rope_traditional,
            base=config.rope_theta,
            scale=1.0,
        )

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        bsz, q_len, _ = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
            0, 2, 1, 3
        )
        values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rotary_emb(queries, offset=cache.offset)
            keys = self.rotary_emb(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rotary_emb(queries)
            keys = self.rotary_emb(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries,
            keys,
            values,
            scale=self.scale,
            mask=attention_mask,
        )
        output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))  # type: ignore


class PlamoDecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[Any, ...]:
        # from LlamaDecoder
        residual = hidden_states

        hidden_states = self.norm(hidden_states)

        # Self Attention
        hidden_states_sa = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            cache=cache,
        )

        # Fully Connected
        hidden_states_mlp = self.mlp(hidden_states)

        hidden_states = residual + hidden_states_sa + hidden_states_mlp
        return hidden_states


class PlamoDecoder(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.layers = [
            PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]


class PlamoModel(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = PlamoDecoder(config)  # type: ignore
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
    ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(self.embed_tokens.weight.dtype)

        if cache is None:
            cache = [None for _ in range(len(self.layers.layers))]

        for layer, c in zip(self.layers.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        self.model_type = args.model_type
        self.model = PlamoModel(args)
        self.lm_head: nn.Module = nn.Linear(
            args.hidden_size, args.vocab_size, bias=False
        )
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
    ) -> Tuple[mx.array, mx.array]:
        out = self.model(inputs, cache)
        return self.lm_head(out)

    @property
    def layers(self):
        return self.model.layers.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_attention_heads // self.args.n_shared_head
llms/mlx_lm/models/qwen.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from dataclasses import dataclass
from typing import Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int = 2048
    num_attention_heads: int = 16
    num_hidden_layers: int = 24
    kv_channels: int = 128
    max_position_embeddings: int = 8192
    layer_norm_epsilon: float = 1e-6
    intermediate_size: int = 11008
    no_bias: bool = True
    vocab_size: int = 151936
    num_key_value_heads = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        hidden_size = args.hidden_size
        self.num_attention_heads = args.num_attention_heads

        hidden_size_per_attention_head = hidden_size // self.num_attention_heads

        self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)

        proj_size = args.kv_channels * self.num_attention_heads

        self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
        self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)

        self.scale = hidden_size_per_attention_head**-0.5

    def __call__(self, x, mask=None, cache=None):
        qkv = self.c_attn(x)

        q, k, v = mx.split(qkv, 3, axis=-1)

        B, L, _ = q.shape

        queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
        keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
        values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rotary_emb(queries, offset=cache.offset)
            keys = self.rotary_emb(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rotary_emb(queries)
            keys = self.rotary_emb(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(
            args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
        )
        self.w2 = nn.Linear(
            args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
        )
        self.c_proj = nn.Linear(
            args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
        )

    def __call__(self, x):
        a1 = self.w1(x)
        a2 = self.w2(x)
        return self.c_proj(a1 * nn.silu(a2))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        self.attn = Attention(args)
        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        self.mlp = MLP(args)

    def __call__(self, x, mask=None, cache=None):
        residual = x
        x = self.ln_1(x)
        x = self.attn(x, mask=mask, cache=cache)
        residual = x + residual
        x = self.ln_2(residual)
        x = self.mlp(x)
        x = x + residual

        return x


class QwenModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
        self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, inputs, mask=None, cache=None):
        x = self.wte(inputs)

        mask = None
        T = x.shape[1]
        if T > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
            mask = mask.astype(x.dtype)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            x = layer(x, mask, c)

        return self.ln_f(x)


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model_type = config.model_type
        self.transformer = QwenModel(config)
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=not config.no_bias
        )
        self.args = config

    def __call__(
        self,
        x: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> Tuple[mx.array, mx.array]:
        y = self.transformer(x, mask, cache)
        return self.lm_head(y)

    @property
    def layers(self):
        return self.transformer.h

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_attention_heads
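For reference, a minimal smoke test of the new qwen.py module above (not part of the diff; assumes `mlx` is installed and the package is importable as `mlx_lm`). It builds a tiny configuration in which `hidden_size == kv_channels * num_attention_heads`, which is what the fused `c_attn` projection expects, and runs one forward pass.

# Illustrative only -- exercises mlx_lm/models/qwen.py with a toy config.
import mlx.core as mx
from mlx_lm.models.qwen import Model, ModelArgs

args = ModelArgs(
    model_type="qwen",
    hidden_size=128,          # equals kv_channels * num_attention_heads below
    num_attention_heads=4,
    num_hidden_layers=2,
    kv_channels=32,
    intermediate_size=256,
    vocab_size=1000,
)
model = Model(args)
logits = model(mx.array([[1, 2, 3, 4]]))
print(logits.shape)  # (1, 4, 1000)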
llms/mlx_lm/models/qwen2.py (new file, 206 lines)
@@ -0,0 +1,206 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 1000000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class Qwen2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen2Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
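When `tie_word_embeddings` is set, the qwen2 `Model` above reuses the input embedding matrix as the output projection via `embed_tokens.as_linear`; conceptually that is a matmul with the transposed embedding table. A small illustrative check (not part of the diff; assumes `mlx` is installed):

# Illustrative only -- tied-embedding projection is h @ W_embed.T.
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(10, 4)            # vocab 10, hidden 4
h = mx.random.normal((1, 3, 4))      # hidden states
logits = emb.as_linear(h)            # (1, 3, 10)
same = h @ emb.weight.T              # equivalent projection
print(mx.allclose(logits, same))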
llms/mlx_lm/models/qwen2_moe.py (new file, 246 lines)
@@ -0,0 +1,246 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    num_experts_per_tok: int
    num_experts: int
    moe_intermediate_size: int
    shared_expert_intermediate_size: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 1000000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Qwen2MoeSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        intermediate_size = args.moe_intermediate_size
        shared_expert_intermediate_size = args.shared_expert_intermediate_size

        self.num_experts = num_experts = args.num_experts
        self.top_k = args.num_experts_per_tok

        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)

        self.shared_expert = MLP(dim, shared_expert_intermediate_size)
        self.shared_expert_gate = nn.Linear(dim, 1, bias=False)

    def __call__(
        self,
        x: mx.array,
    ):
        gates = self.gate(x)
        gates = mx.softmax(gates, axis=-1, precise=True)

        k = self.top_k
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(gates, inds, axis=-1)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)

        shared_expert_output = self.shared_expert(x)
        shared_expert_output = (
            mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
        )

        return y + shared_expert_output


class Qwen2MoeDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = Qwen2MoeSparseMoeBlock(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class Qwen2MoeModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Qwen2MoeModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
            return weights
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n in ["up_proj", "down_proj", "gate_proj"]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
                            for e in range(self.args.num_experts)
                        ]
                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
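The routing in `Qwen2MoeSparseMoeBlock` picks the top-k experts per token with `mx.argpartition` and gathers their gate probabilities with `mx.take_along_axis`. A standalone sketch of just that selection step (illustrative, not part of the diff):

# Toy routing example: 1 token, 4 experts, top-2 selection.
import mlx.core as mx

gates = mx.softmax(mx.array([[0.1, 2.0, 0.3, 1.5]]), axis=-1)   # (1, 4) gate probs
k = 2
inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]     # indices of the 2 largest gates
scores = mx.take_along_axis(gates, inds, axis=-1)               # their routing weights
print(inds.tolist(), scores.tolist())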
llms/mlx_lm/models/stablelm.py (new file, 219 lines)
@@ -0,0 +1,219 @@
import math
from dataclasses import dataclass
from typing import Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    intermediate_size: int
    rope_theta: float
    use_qkv_bias: bool
    partial_rotary_factor: float
    layer_norm_eps: float
    use_parallel_residual: bool = False
    qk_layernorm: bool = False


class LayerNormPerHead(nn.Module):

    def __init__(self, head_dim, num_heads, eps):
        super().__init__()
        self.norms = [
            nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
        ]
        self.eps = eps

    def __call__(self, x):
        w = mx.stack([n.weight for n in self.norms])
        return w * mx.fast.layer_norm(x, None, None, self.eps)


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.use_qkv_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.use_qkv_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            traditional=False,
            base=self.rope_theta,
        )

        self.qk_layernorm = config.qk_layernorm
        if self.qk_layernorm:
            self.q_layernorm = LayerNormPerHead(
                self.head_dim, self.num_heads, eps=config.layer_norm_eps
            )
            self.k_layernorm = LayerNormPerHead(
                self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
            )

    def __call__(self, x, mask=None, cache=None):
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Extract some shapes
        B, L, D = queries.shape

        queries = queries.reshape(B, L, self.num_heads, -1)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1)
        if self.qk_layernorm:
            queries = self.q_layernorm(queries)
            keys = self.k_layernorm(keys)
        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        queries = queries.astype(mx.float32)
        keys = keys.astype(mx.float32)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=scale, mask=mask
        ).astype(values.dtype)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class DecoderLayer(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.self_attn = Attention(config=config)
        self.mlp = MLP(config.hidden_size, config.intermediate_size)
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )
        self.use_parallel_residual = config.use_parallel_residual
        if not self.use_parallel_residual:
            self.post_attention_layernorm = nn.LayerNorm(
                config.hidden_size,
                eps=config.layer_norm_eps,
            )

    def __call__(self, x, mask, cache):
        h = self.input_layernorm(x)
        r = self.self_attn(h, mask, cache)

        if self.use_parallel_residual:
            out = x + r + self.mlp(h)
        else:
            h = x + r
            r = self.mlp(self.post_attention_layernorm(h))
            out = h + r
        return out


class StableLM(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(self, x, mask, cache):
        x = self.embed_tokens(x)
        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            x = layer(x, mask, cache=c)

        return self.norm(x)


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model_type = config.model_type
        self.model = StableLM(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.args = config

    def __call__(
        self,
        x: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> Tuple[mx.array, mx.array]:
        mask = None
        if x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)

        y = self.model(x, mask, cache)
        return self.lm_head(y)

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
llms/mlx_lm/models/starcoder2.py (new file, 175 lines)
@@ -0,0 +1,175 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    norm_epsilon: float = 1e-5
    vocab_size: int = 49152
    rope_theta: float = 100000
    tie_word_embeddings: bool = True


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // args.num_attention_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
        self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=True)

    def __call__(self, x):
        return self.c_proj(nn.gelu(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.n_heads = args.num_attention_heads

        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.norm_epsilon
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class Starcoder2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Starcoder2Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
llms/mlx_lm/models/su_rope.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import math
from typing import List, Union

import mlx.core as mx


class SuScaledRotaryEmbedding:
    def __init__(
        self,
        dims: int,
        traditional: bool = False,
        base: float = 10000.0,
        scale: float = 1.0,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        short_factor: Union[List[float], float] = 1.0,
        long_factor: Union[List[float], float] = 1.0,
    ):
        """
        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.

        Args:
            dims (int): The feature dimensions to be rotated.
            traditional (bool, optional): Unused. Default: ``False``.
            base (int, optional): Base for the exponential scaling.
            scale (float, optional): The scale used to scale the positions.
              Default: ``1.0``.
            max_position_embeddings (int, optional): The maximum sequence
              length that this model was trained with. This is used to determine
              the size of the original RoPE embeddings when using long scaling.
              Default: ``131072``.
            original_max_position_embeddings (int, optional): The maximum
              sequence length that this model was trained with. This is used to
              determine the size of the original RoPE embeddings when using long
              scaling. Default: ``4096``.
            short_factor (float or list[float], optional): List of scaling
              factors for sequences of length lesser than
              ``original_max_position_embeddings``. Default: ``1.0``.
            long_factor (float or list[float], optional): List of scaling
              factors for sequences of length greater than
              ``original_max_position_embeddings``. Default: ``1.0``.
        """
        self.inv_freq_short = 1.0 / (
            mx.array(short_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.inv_freq_long = 1.0 / (
            scale
            * mx.array(long_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.original_max_position_embeddings = original_max_position_embeddings
        self.scaling_factor = math.sqrt(
            1
            + math.log(max_position_embeddings / original_max_position_embeddings)
            / math.log(original_max_position_embeddings)
        )

    def _get_cos_sin(self, offset, L):
        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
        inv_freq = (
            self.inv_freq_long
            if (offset + L) > self.original_max_position_embeddings
            else self.inv_freq_short
        )
        freqs = position_ids[:, None] * inv_freq[None, :]
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return cos, sin

    def __call__(self, x, offset: int = 0):
        def _rotate_half(_x):
            midpoint = _x.shape[-1] // 2
            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)

        cos, sin = self._get_cos_sin(offset, x.shape[2])
        return (x * cos) + (_rotate_half(x) * sin)
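A quick illustrative call of `SuScaledRotaryEmbedding` on a dummy (batch, heads, sequence, dims) tensor (not part of the diff; the `short_factor`/`long_factor` lists here are placeholder values, one per rotated frequency):

# Illustrative only -- rotating 16 positions of an 8-dim feature.
import mlx.core as mx
from mlx_lm.models.su_rope import SuScaledRotaryEmbedding

rope = SuScaledRotaryEmbedding(dims=8, short_factor=[1.0] * 4, long_factor=[1.2] * 4)
x = mx.random.normal((1, 2, 16, 8))   # (batch, heads, sequence, dims)
y = rope(x, offset=0)                 # same shape, positions 0..15 rotated
print(y.shape)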
llms/mlx_lm/models/switch_layers.py (new file, 165 lines)
@@ -0,0 +1,165 @@
import math

import mlx.core as mx
import mlx.nn as nn


class QuantizedSwitchLinear(nn.Module):
    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        num_experts: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
    ):
        super().__init__()

        scale = math.sqrt(1 / input_dims)
        self.weight, self.scales, self.biases = mx.quantize(
            mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(num_experts, output_dims, input_dims),
            ),
            group_size=group_size,
            bits=bits,
        )

        if bias:
            self.bias = mx.zeros((num_experts, output_dims))

        self.group_size = group_size
        self.bits = bits

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    @property
    def input_dims(self):
        return self.scales.shape[2] * self.group_size

    @property
    def output_dims(self):
        return self.weight.shape[1]

    @property
    def num_experts(self):
        return self.weight.shape[0]

    def __call__(self, x, indices):
        x = mx.gather_qmm(
            x,
            self["weight"],
            self["scales"],
            self["biases"],
            rhs_indices=indices,
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + mx.expand_dims(self["bias"][indices], -2)
        return x


class SwitchLinear(nn.Module):
    def __init__(
        self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
    ):
        super().__init__()
        scale = math.sqrt(1 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(num_experts, output_dims, input_dims),
        )

        if bias:
            self.bias = mx.zeros((num_experts, output_dims))

    @property
    def input_dims(self):
        return self.weight.shape[2]

    @property
    def output_dims(self):
        return self.weight.shape[1]

    @property
    def num_experts(self):
        return self.weight.shape[0]

    def __call__(self, x, indices):
        x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
        if "bias" in self:
            x = x + mx.expand_dims(self["bias"][indices], -2)
        return x

    def to_quantized(self, group_size: int = 64, bits: int = 4):
        num_experts, output_dims, input_dims = self.weight.shape
        ql = QuantizedSwitchLinear(
            input_dims, output_dims, num_experts, False, group_size, bits
        )
        ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
        if "bias" in self:
            ql.bias = self.bias
        return ql


class SwitchGLU(nn.Module):
    def __init__(
        self,
        input_dims: int,
        hidden_dims: int,
        num_experts: int,
        activation=nn.silu,
        bias: bool = False,
    ):
        super().__init__()

        self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
        self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
        self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
        self.activation = activation

    def __call__(self, x, indices) -> mx.array:
        x = mx.expand_dims(x, (-2, -3))

        x_up = self.up_proj(x, indices)
        x_gate = self.gate_proj(x, indices)
        x = self.down_proj(self.activation(x_gate) * x_up, indices)

        return x.squeeze(-2)


class SwitchMLP(nn.Module):
    def __init__(
        self,
        input_dims: int,
        hidden_dims: int,
        num_experts: int,
        activation=nn.gelu_approx,
        bias: bool = False,
    ):
        super().__init__()

        self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
        self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
        self.activation = activation

    def __call__(self, x, indices) -> mx.array:
        x = mx.expand_dims(x, (-2, -3))

        x = self.fc1(x, indices)
        x = self.activation(x)
        x = self.fc2(x, indices)

        return x.squeeze(-2)
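Sketch of how `SwitchGLU` is driven by per-token expert indices, matching its use in the Qwen2 MoE block above (illustrative only, not part of the diff):

# Illustrative only -- one SwitchGLU call with random top-2 expert assignments.
import mlx.core as mx
from mlx_lm.models.switch_layers import SwitchGLU

glu = SwitchGLU(input_dims=16, hidden_dims=32, num_experts=4)
x = mx.random.normal((1, 5, 16))                  # (batch, tokens, features)
inds = mx.random.randint(0, 4, shape=(1, 5, 2))   # top-2 expert ids per token
y = glu(x, inds)                                  # (1, 5, 2, 16): one output per selected expert
print(y.shape)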
llms/mlx_lm/py.typed (new file, 1 line; empty PEP 561 marker)
@@ -0,0 +1 @@

llms/mlx_lm/requirements.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
mlx>=0.14.1
numpy
transformers>=4.39.3
protobuf
pyyaml
jinja2
llms/mlx_lm/sample_utils.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import mlx.core as mx


def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    Args:
        logits: The logits from the model's output.
        top_p: The cumulative probability threshold for top-p filtering.
        temperature: Temperature parameter for softmax distribution reshaping.
    Returns:
        token selected based on the top-p criterion.
    """
    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
    probs = mx.softmax(logits / temperature, axis=-1)

    # sort probs in ascending order
    sorted_indices = mx.argsort(probs, axis=-1)
    sorted_probs = probs[..., sorted_indices.squeeze(0)]

    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

    # select tokens with cumulative probs below threshold
    top_probs = mx.where(
        cumulative_probs > 1 - top_p,
        sorted_probs,
        mx.zeros_like(sorted_probs),
    )

    sorted_token = mx.random.categorical(mx.log(top_probs))
    token = sorted_indices.squeeze(0)[sorted_token]

    return token
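A small usage sketch of `top_p_sampling` (illustrative, not part of the diff). Note the function expects a single-step batch of one, since it squeezes axis 0 of the sorted indices:

# Illustrative only -- sample one token id from (1, vocab) logits.
import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling

logits = mx.array([[1.0, 2.0, 0.5, -1.0]])          # logits for a single decode step
token = top_p_sampling(logits, top_p=0.9, temperature=1.0)
print(token.item())                                  # index of the sampled token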
551
llms/mlx_lm/server.py
Normal file
551
llms/mlx_lm/server.py
Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import warnings
|
||||||
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
from typing import List, Literal, NamedTuple, Optional, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .tokenizer_utils import TokenizerWrapper
|
||||||
|
from .utils import generate_step, load
|
||||||
|
|
||||||
|
|
||||||
|
class StopCondition(NamedTuple):
|
||||||
|
stop_met: bool
|
||||||
|
trim_length: int
|
||||||
|
|
||||||
|
|
||||||
|
def stopping_criteria(
|
||||||
|
tokens: List[int],
|
||||||
|
stop_id_sequences: List[List[int]],
|
||||||
|
eos_token_id: Union[int, None],
|
||||||
|
) -> StopCondition:
|
||||||
|
"""
|
||||||
|
Determines whether the token generation should stop based on predefined conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens (List[int]): The current sequence of generated tokens.
|
||||||
|
stop_id_sequences (List[List[[int]]): A list of integer lists, each representing a sequence of token IDs.
|
||||||
|
If the end of the `tokens` list matches any of these sequences, the generation should stop.
|
||||||
|
eos_token_id (Union[int, None]): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this,
|
||||||
|
the generation should stop.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`)
|
||||||
|
and how many tokens should be trimmed from the end if it has (`trim_length`).
|
||||||
|
"""
|
||||||
|
if tokens and tokens[-1] == eos_token_id:
|
||||||
|
return StopCondition(stop_met=True, trim_length=1)
|
||||||
|
|
||||||
|
for stop_ids in stop_id_sequences:
|
||||||
|
if len(tokens) >= len(stop_ids):
|
||||||
|
if tokens[-len(stop_ids) :] == stop_ids:
|
||||||
|
return StopCondition(stop_met=True, trim_length=len(stop_ids))
|
||||||
|
|
||||||
|
return StopCondition(stop_met=False, trim_length=0)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||||
|
default_role_mapping = {
|
||||||
|
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.",
|
||||||
|
"system": "ASSISTANT's RULE: ",
|
||||||
|
"user": "USER: ",
|
||||||
|
"assistant": "ASSISTANT: ",
|
||||||
|
"stop": "\n",
|
||||||
|
}
|
||||||
|
role_mapping = role_mapping if role_mapping is not None else default_role_mapping
|
||||||
|
|
||||||
|
prompt = ""
|
||||||
|
for line in messages:
|
||||||
|
role_prefix = role_mapping.get(line["role"], "")
|
||||||
|
stop = role_mapping.get("stop", "")
|
||||||
|
content = line.get("content", "")
|
||||||
|
prompt += f"{role_prefix}{content}{stop}"
|
||||||
|
|
||||||
|
prompt += role_mapping.get("assistant", "")
|
||||||
|
return prompt.rstrip()
|
||||||
|
|
||||||
|
|
||||||
|
class APIHandler(BaseHTTPRequestHandler):
    def __init__(self, model: nn.Module, tokenizer: TokenizerWrapper, *args, **kwargs):
        """
        Create static request specific metadata
        """
        self.model = model
        self.tokenizer = tokenizer
        self.created = int(time.time())
        super().__init__(*args, **kwargs)

    def _set_cors_headers(self):
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "*")
        self.send_header("Access-Control-Allow-Headers", "*")

    def _set_completion_headers(self, status_code: int = 200):
        self.send_response(status_code)
        self.send_header("Content-type", "application/json")
        self._set_cors_headers()

    def _set_stream_headers(self, status_code: int = 200):
        self.send_response(status_code)
        self.send_header("Content-type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self._set_cors_headers()

    def do_OPTIONS(self):
        self._set_completion_headers(204)
        self.end_headers()

    def do_POST(self):
        """
        Respond to a POST request from a client.
        """
        endpoints = {
            "/v1/completions": self.handle_text_completions,
            "/v1/chat/completions": self.handle_chat_completions,
        }

        if self.path not in endpoints:
            self._set_completion_headers(404)
            self.end_headers()
            self.wfile.write(b"Not Found")
            return

        # Fetch and parse request body
        content_length = int(self.headers["Content-Length"])
        raw_body = self.rfile.read(content_length)
        self.body = json.loads(raw_body.decode())
        indent = "\t"  # Backslashes can't be inside of f-strings
        logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}")
        assert isinstance(
            self.body, dict
        ), f"Request should be dict, but got {type(self.body)}"

        # Extract request parameters from the body
        self.stream = self.body.get("stream", False)
        self.requested_model = self.body.get("model", "default_model")
        self.max_tokens = self.body.get("max_tokens", 100)
        self.temperature = self.body.get("temperature", 1.0)
        self.top_p = self.body.get("top_p", 1.0)
        self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
        self.repetition_context_size = self.body.get("repetition_context_size", 20)
        self.logit_bias = self.body.get("logit_bias", None)

        self.validate_model_parameters()

        # Get stop id sequences, if provided
        stop_words = self.body.get("stop", [])
        stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
        stop_id_sequences = [
            self.tokenizer.encode(stop_word, add_special_tokens=False)
            for stop_word in stop_words
        ]

        # Send header type
        (
            self._set_stream_headers(200)
            if self.stream
            else self._set_completion_headers(200)
        )

        # Call endpoint specific method
        prompt = endpoints[self.path]()

        # Call method based on response type
        method = self.handle_stream if self.stream else self.handle_completion
        method(prompt, stop_id_sequences)

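    # Example request body handled by do_POST (illustrative values only, not
    # taken from a captured session):
    #   {
    #     "model": "mlx-community/Mistral-7B-Instruct-v0.2-4bit",
    #     "messages": [{"role": "user", "content": "Hello"}],
    #     "max_tokens": 100,
    #     "temperature": 0.7,
    #     "stream": true
    #   }
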
    def validate_model_parameters(self):
        """
        Validate the model parameters passed in the request for the correct types and values.
        """
        if not isinstance(self.stream, bool):
            raise ValueError("stream must be a boolean")

        if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
            raise ValueError("max_tokens must be a non-negative integer")

        if not isinstance(self.temperature, float) or self.temperature < 0:
            raise ValueError("temperature must be a non-negative float")

        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
            raise ValueError("top_p must be a float between 0 and 1")

        if (
            not isinstance(self.repetition_penalty, float)
            or self.repetition_penalty < 0
        ):
            raise ValueError("repetition_penalty must be a non-negative float")

        if (
            not isinstance(self.repetition_context_size, int)
            or self.repetition_context_size < 0
        ):
            raise ValueError("repetition_context_size must be a non-negative integer")

        if self.logit_bias is not None:
            if not isinstance(self.logit_bias, dict):
                raise ValueError("logit_bias must be a dict of int to float")

            try:
                self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
            except ValueError:
                raise ValueError("logit_bias must be a dict of int to float")

        if not isinstance(self.requested_model, str):
            raise ValueError("model must be a string")

    def generate_response(
        self,
        text: str,
        finish_reason: Union[Literal["length", "stop"], None],
        prompt_token_count: Optional[int] = None,
        completion_token_count: Optional[int] = None,
    ) -> dict:
        """
        Generate a single response packet based on response type (stream or not), completion type and parameters.

        Args:
            text (str): Text generated by model
            finish_reason (Union[Literal["length", "stop"], None]):
                The reason the response is being sent: "length", "stop" or None
            prompt_token_count (Optional[int]):
                The number of tokens in the prompt,
                used to populate the "usage" field (not used when streaming)
            completion_token_count (Optional[int]):
                The number of tokens in the response,
                used to populate the "usage" field (not used when streaming)

        Returns:
            dict: A dictionary containing the response, imitating OpenAI's API
        """

        # Static response
        response = {
            "id": self.request_id,
            "system_fingerprint": f"fp_{uuid.uuid4()}",
            "object": self.object_type,
            "model": self.requested_model,
            "created": self.created,
            "choices": [
                {
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": finish_reason,
                }
            ],
        }

        if not self.stream:
            if not (
                isinstance(prompt_token_count, int)
                and isinstance(completion_token_count, int)
            ):
                raise ValueError(
                    "Response type is complete, but token counts not provided"
                )

            response["usage"] = {
                "prompt_tokens": prompt_token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": prompt_token_count + completion_token_count,
            }

        choice = response["choices"][0]

        # Add dynamic response
        if self.object_type.startswith("chat.completion"):
            key_name = "delta" if self.stream else "message"
            choice[key_name] = {"role": "assistant", "content": text}
        elif self.object_type == "text_completion":
            choice.update(text=text)
        else:
            raise ValueError(f"Unsupported response type: {self.object_type}")

        return response

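    # For reference, a non-streaming chat response built above has roughly this
    # shape (field values are illustrative):
    #   {
    #     "id": "chatcmpl-<uuid>",
    #     "object": "chat.completions",
    #     "model": "<requested model>",
    #     "choices": [{"index": 0, "finish_reason": "stop",
    #                  "message": {"role": "assistant", "content": "..."}}],
    #     "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}
    #   }
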
    def handle_completion(
        self,
        prompt: mx.array,
        stop_id_sequences: List[List[int]],
    ):
        """
        Generate a response to a prompt and send it to the client in a single batch.

        Args:
            prompt (mx.array): The prompt, in token form inside of a mlx array
            stop_id_sequences (List[List[int]]):
                A list of stop words passed to the stopping_criteria function
        """
        detokenizer = self.tokenizer.detokenizer
        detokenizer.reset()
        tokens = []
        finish_reason = "length"
        stop_sequence_suffix = None
        logging.debug(f"Starting completion:")
        for (token, _), _ in zip(
            generate_step(
                prompt=prompt,
                model=self.model,
                temp=self.temperature,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                repetition_context_size=self.repetition_context_size,
                logit_bias=self.logit_bias,
            ),
            range(self.max_tokens),
        ):
            detokenizer.add_token(token)
            logging.debug(detokenizer.text)
            tokens.append(token)
            stop_condition = stopping_criteria(
                tokens, stop_id_sequences, self.tokenizer.eos_token_id
            )
            if stop_condition.stop_met:
                finish_reason = "stop"
                if stop_condition.trim_length:
                    stop_sequence_suffix = self.tokenizer.decode(
                        tokens[-stop_condition.trim_length :]
                    )
                break

        detokenizer.finalize()
        text = (
            detokenizer.text
            if stop_sequence_suffix is None
            else detokenizer.text[: -len(stop_sequence_suffix)]
        )
        response = self.generate_response(text, finish_reason, len(prompt), len(tokens))

        response_json = json.dumps(response).encode()
        indent = "\t"  # Backslashes can't be inside of f-strings
        logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")

        # Send an additional Content-Length header when it is known
        self.send_header("Content-Length", str(len(response_json)))
        self.end_headers()

        self.wfile.write(response_json)
        self.wfile.flush()

    def handle_stream(
        self,
        prompt: mx.array,
        stop_id_sequences: List[List[int]],
    ):
        """
        Generate a response to the prompt and forward it to the client using a Server Sent Events (SSE) stream.

        Args:
            prompt (mx.array): The prompt, in token form inside of a mlx array
            stop_id_sequences (List[List[int]]):
                A list of stop words passed to the stopping_criteria function
        """
        # No additional headers are needed, call end_headers
        self.end_headers()

        detokenizer = self.tokenizer.detokenizer
        detokenizer.reset()
        tokens = []

        max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
        # Buffer to store the last `max_stop_id_sequence_len` tokens
        # to check for stop conditions before writing to the stream.
        stop_sequence_buffer = []
        stop_sequence_suffix = None
        logging.debug(f"Starting stream:")
        for (token, _), _ in zip(
            generate_step(
                prompt=prompt,
                model=self.model,
                temp=self.temperature,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                repetition_context_size=self.repetition_context_size,
            ),
            range(self.max_tokens),
        ):
            detokenizer.add_token(token)
            logging.debug(detokenizer.text)
            tokens.append(token)
            stop_sequence_buffer.append(token)

            # Continue generating tokens until buffer is as large as the longest stop_id_sequence
            if len(stop_sequence_buffer) < max_stop_id_sequence_len:
                continue

            stop_condition = stopping_criteria(
                tokens,
                stop_id_sequences,
                self.tokenizer.eos_token_id,
            )
            if stop_condition.stop_met:
                if stop_condition.trim_length:
                    stop_sequence_suffix = self.tokenizer.decode(
                        tokens[-stop_condition.trim_length :]
                    )
                break

            new_text = detokenizer.last_segment
            response = self.generate_response(new_text, None)
            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
            self.wfile.flush()
            stop_sequence_buffer = []

        # Check whether there is any remaining text to send
        if stop_sequence_buffer:
            next_chunk = (
                detokenizer.last_segment
                if stop_sequence_suffix is None
                else detokenizer.last_segment[: -len(stop_sequence_suffix)]
            )
            response = self.generate_response(next_chunk, "length")

            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
            self.wfile.flush()

        self.wfile.write("data: [DONE]\n\n".encode())
        self.wfile.flush()

    def handle_chat_completions(self) -> mx.array:
        """
        Handle a chat completion request.

        Returns:
            mx.array: A mx.array of the tokenized prompt from the request body
        """
        body = self.body
        assert "messages" in body, "Request did not contain messages"

        # Determine response type
        self.request_id = f"chatcmpl-{uuid.uuid4()}"
        self.object_type = (
            "chat.completions.chunk" if self.stream else "chat.completions"
        )

        if (
            hasattr(self.tokenizer, "apply_chat_template")
            and self.tokenizer.chat_template
        ):
            prompt = self.tokenizer.apply_chat_template(
                body["messages"],
                tokenize=True,
                add_generation_prompt=True,
            )
        else:
            prompt = convert_chat(body["messages"], body.get("role_mapping"))
            prompt = self.tokenizer.encode(prompt)

        return mx.array(prompt)

    def handle_text_completions(self) -> mx.array:
        """
        Handle a text completion request.

        Returns:
            mx.array: A mx.array of the tokenized prompt from the request body
        """
        # Determine response type
        self.request_id = f"cmpl-{uuid.uuid4()}"
        self.object_type = "text_completion"

        assert "prompt" in self.body, "Request did not contain a prompt"
        prompt_text = self.body["prompt"]

        prompt = self.tokenizer.encode(prompt_text)
        return mx.array(prompt)


def run(
    host: str,
    port: int,
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    server_class=HTTPServer,
    handler_class=APIHandler,
):
    server_address = (host, port)
    httpd = server_class(
        server_address,
        lambda *args, **kwargs: handler_class(model, tokenizer, *args, **kwargs),
    )
    warnings.warn(
        "mlx_lm.server is not recommended for production as "
        "it only implements basic security checks."
    )
    logging.info(f"Starting httpd at {host} on port {port}...")
    httpd.serve_forever()


def main():
    parser = argparse.ArgumentParser(description="MLX Http Server.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="The path to the MLX model weights, tokenizer, and config",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host for the HTTP server (default: 127.0.0.1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port for the HTTP server (default: 8080)",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level (default: INFO)",
    )
    parser.add_argument(
        "--cache-limit-gb",
        type=int,
        default=None,
        help="Set the MLX cache limit in GB",
        required=False,
    )
    args = parser.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), None),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    if args.cache_limit_gb is not None:
        logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

    # Building tokenizer_config
    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}

    model, tokenizer = load(
        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
    )
    run(args.host, args.port, model, tokenizer)


if __name__ == "__main__":
    main()

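A quick way to exercise the server above (an editorial sketch, not part of the diff): start it, typically with `python -m mlx_lm.server --model <path-or-hf-repo>`, then POST to /v1/chat/completions, for example with only the standard library:

import json
import urllib.request

# Assumes a local server on the default host/port (127.0.0.1:8080).
req = urllib.request.Request(
    "http://127.0.0.1:8080/v1/chat/completions",
    data=json.dumps(
        {
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 50,
            "temperature": 0.7,
        }
    ).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])
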
330 llms/mlx_lm/tokenizer_utils.py Normal file
@@ -0,0 +1,330 @@
import json
from functools import partial

from transformers import AutoTokenizer

REPLACEMENT_CHAR = "\ufffd"


def _remove_space(x):
    if x and x[0] == " ":
        return x[1:]
    return x


class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text so far. Some tokens may not be included
            # since it contains whole words usually.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains all the tokens added so far
            detokenizer.tokens

        # Make sure that we detokenize any remaining tokens
        detokenizer.finalize()

        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
    """

    __slots__ = ("text", "tokens", "offset")

    def reset(self):
        raise NotImplementedError()

    def add_token(self, token):
        raise NotImplementedError()

    def finalize(self):
        raise NotImplementedError()

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        if text and text[-1] != REPLACEMENT_CHAR:
            segment = text[self.offset :]
            self.offset = len(text)
            return segment
        return ""


class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
    implementation and should work with every tokenizer.

    Its complexity is O(T^2) where T is the longest line since it will
    repeatedly detokenize the same tokens until a new line is generated.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._tokenizer.decode([0])
        self.reset()

    def reset(self):
        self.offset = 0
        self._tokens = []
        self._text = ""
        self._current_tokens = []
        self._current_text = ""

    def add_token(self, token):
        self._current_tokens.append(token)

    def finalize(self):
        self._tokens.extend(self._current_tokens)
        self._text += self._tokenizer.decode(self._current_tokens)
        self._current_tokens = []
        self._current_text = ""

    @property
    def text(self):
        if self._current_tokens:
            self._current_text = self._tokenizer.decode(self._current_tokens)
        if self._current_text and self._current_text[-1] == "\n":
            self._tokens.extend(self._current_tokens)
            self._text += self._current_text
            self._current_tokens.clear()
            self._current_text = ""
        return self._text + self._current_text

    @property
    def tokens(self):
        return self._tokens


class SPMStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for SPM models.

    It adds tokens to the text if the next token starts with the special SPM
    underscore which results in linear complexity.
    """

    def __init__(self, tokenizer, trim_space=True):
        self.trim_space = trim_space

        # Extract the tokens in a list from id to text
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        # Replace bytes with their value
        for i in range(len(self.tokenmap)):
            if self.tokenmap[i].startswith("<0x"):
                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))

        self.reset()

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def add_token(self, token):
        v = self.tokenmap[token]
        if v[0] == "\u2581":
            if self.text or not self.trim_space:
                self.text += self._unflushed.replace("\u2581", " ")
            else:
                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
            self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        if self.text or not self.trim_space:
            self.text += self._unflushed.replace("\u2581", " ")
        else:
            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
        self._unflushed = ""


class BPEStreamingDetokenizer(StreamingDetokenizer):
    """A streaming detokenizer for OpenAI style BPE models.

    It adds tokens to the text if the next token starts with a space similar to
    the SPM detokenizer.
    """

    _byte_decoder = None

    def __init__(self, tokenizer, trim_space=False):
        self.trim_space = trim_space

        # Extract the tokens in a list from id to text
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        self.reset()

        # Make the BPE byte decoder from
        # https://github.com/openai/gpt-2/blob/master/src/encoder.py
        self.make_byte_decoder()

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def add_token(self, token):
        v = self.tokenmap[token]
        # if the token starts with space
        if self._byte_decoder[v[0]] == 32:
            current_text = bytearray(
                self._byte_decoder[c] for c in self._unflushed
            ).decode("utf-8")
            if self.text or not self.trim_space:
                self.text += current_text
            else:
                self.text += _remove_space(current_text)
            self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
            "utf-8"
        )
        if self.text or not self.trim_space:
            self.text += current_text
        else:
            self.text += _remove_space(current_text)
        self._unflushed = ""

    @classmethod
    def make_byte_decoder(cls):
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
        if cls._byte_decoder is not None:
            return

        char_to_bytes = {}
        limits = [
            0,
            ord("!"),
            ord("~") + 1,
            ord("¡"),
            ord("¬") + 1,
            ord("®"),
            ord("ÿ") + 1,
        ]
        n = 0
        for i, (start, stop) in enumerate(zip(limits, limits[1:])):
            if i % 2 == 0:
                for b in range(start, stop):
                    char_to_bytes[chr(2**8 + n)] = b
                    n += 1
            else:
                for b in range(start, stop):
                    char_to_bytes[chr(b)] = b
        cls._byte_decoder = char_to_bytes


class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.

    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """

    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
        self._tokenizer = tokenizer
        self._detokenizer = detokenizer_class(tokenizer)

    def __getattr__(self, attr):
        if attr == "detokenizer":
            return self._detokenizer
        else:
            return getattr(self._tokenizer, attr)


def _match(a, b):
    if type(a) != type(b):
        return False
    if isinstance(a, dict):
        return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
    if isinstance(a, list):
        return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))

    return a == b


def _is_spm_decoder(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
            {"type": "Strip", "content": " ", "start": 1, "stop": 0},
        ],
    }
    return _match(_target_description, decoder)


def _is_spm_decoder_no_space(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
        ],
    }
    return _match(_target_description, decoder)


def _is_bpe_decoder(decoder):
    _target_description = {
        "type": "ByteLevel",
        "add_prefix_space": False,
        "trim_offsets": False,
        "use_regex": False,
    }

    return _match(_target_description, decoder)


def load_tokenizer(model_path, tokenizer_config_extra={}):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """
    detokenizer_class = NaiveStreamingDetokenizer

    tokenizer_file = model_path / "tokenizer.json"
    if tokenizer_file.exists():
        with open(tokenizer_file, "r") as fid:
            tokenizer_content = json.load(fid)
        if "decoder" in tokenizer_content:
            if _is_spm_decoder(tokenizer_content["decoder"]):
                detokenizer_class = SPMStreamingDetokenizer
            elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
                detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
            elif _is_bpe_decoder(tokenizer_content["decoder"]):
                detokenizer_class = BPEStreamingDetokenizer

    return TokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
        detokenizer_class,
    )
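A rough usage sketch for the wrapper above (paths are placeholders, not taken from the diff):

from pathlib import Path
from mlx_lm.tokenizer_utils import load_tokenizer

# Assumes the model files were already downloaded to a local directory.
tokenizer = load_tokenizer(Path("path/to/model"))
detokenizer = tokenizer.detokenizer

detokenizer.reset()
for token in tokenizer.encode("Hello world"):
    detokenizer.add_token(token)
    print(detokenizer.last_segment, end="")
detokenizer.finalize()
print(detokenizer.last_segment)
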
2 llms/mlx_lm/tuner/__init__.py Normal file
@@ -0,0 +1,2 @@
from .trainer import TrainingArgs, evaluate, train
from .utils import linear_to_lora_layers

104 llms/mlx_lm/tuner/datasets.py Normal file
@@ -0,0 +1,104 @@
import json
from pathlib import Path

from transformers import PreTrainedTokenizer


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path):
        with open(path, "r") as fid:
            self._data = [json.loads(l) for l in fid]

    def __getitem__(self, idx: int):
        return self._data[idx]["text"]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)


class ChatDataset(Dataset):
    """
    A dataset for chat data in the format of {"messages": [...]}
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        messages = self._data[idx]["messages"]
        text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return text


class CompletionsDataset(Dataset):
    """
    A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        data = self._data[idx]
        text = self._tokenizer.apply_chat_template(
            [
                {"role": "user", "content": data["prompt"]},
                {"role": "assistant", "content": data["completion"]},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        return text


def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
    # Return empty dataset for non-existent paths
    if not path.exists():
        return []
    with open(path, "r") as fid:
        first_line = next(fid)
        first_obj = json.loads(first_line)
        if "messages" in first_obj:
            return ChatDataset(path, tokenizer)
        elif "prompt" in first_obj and "completion" in first_obj:
            return CompletionsDataset(path, tokenizer)
        elif "text" in first_obj:
            return Dataset(path)
        else:
            raise ValueError(
                "Unsupported data format, check the supported formats here:\n"
                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
            )


def load_dataset(args, tokenizer: PreTrainedTokenizer):
    names = ("train", "valid", "test")
    data_path = Path(args.data)
    train, valid, test = [
        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
    ]
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test
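For reference, the three jsonl record shapes recognized by create_dataset above look like this (contents are illustrative):

{"text": "The quick brown fox jumps over the lazy dog."}
{"prompt": "What is 2 + 2?", "completion": "4"}
{"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
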
@@ -6,40 +6,47 @@ import mlx.core as mx
 import mlx.nn as nn


-class LoRALinear(nn.Module):
+class DoRALinear(nn.Module):
     @staticmethod
-    def from_base(
+    def from_linear(
         linear: nn.Linear,
         r: int = 8,
         dropout: float = 0.0,
-        scale: float = 1.0,
+        scale: float = 20.0,
     ):
+        # TODO support quantized weights in DoRALinear
         output_dims, input_dims = linear.weight.shape
-        lora_lin = LoRALinear(
+        if isinstance(linear, nn.QuantizedLinear):
+            raise ValueError("DoRALinear does not yet support quantization.")
+        dora_lin = DoRALinear(
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
             dropout=dropout,
             scale=scale,
         )
-        lora_lin.linear = linear
-        return lora_lin
+        dora_lin.linear = linear
+        return dora_lin

-    def fuse(self):
+    def to_linear(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
+
+        # Use the same type as the linear weight if not quantized
         dtype = weight.dtype
+
         output_dims, input_dims = weight.shape
         fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
-        lora_b = self.scale * self.lora_b.T
-        lora_a = self.lora_a.T
-        fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
+        lora_b = (self.scale * self.lora_b.T).astype(dtype)
+        lora_a = self.lora_a.T.astype(dtype)
+        weight = weight + lora_b @ lora_a
+        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+        fused_linear.weight = norm_scale[:, None] * weight
+
         if bias:
             fused_linear.bias = linear.bias
+
         return fused_linear

     def __init__(
@@ -48,14 +55,13 @@ class LoRALinear(nn.Module):
         output_dims: int,
         r: int = 8,
         dropout: float = 0.0,
-        scale: float = 1.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()

         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)

         self.dropout = nn.Dropout(p=dropout)

         # Scale for low-rank update
@@ -69,8 +75,21 @@ class LoRALinear(nn.Module):
             shape=(input_dims, r),
         )
         self.lora_b = mx.zeros(shape=(r, output_dims))
+        self.m = mx.linalg.norm(self.linear.weight, axis=1)

     def __call__(self, x):
-        y = self.linear(x)
+        # Regular LoRA (without a bias)
+        y = x @ self.linear.weight.T
         z = (self.dropout(x) @ self.lora_a) @ self.lora_b
-        return y + (self.scale * z).astype(x.dtype)
+        out = y + (self.scale * z).astype(x.dtype)
+
+        # Compute the norm of the adapted weights
+        adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m / denom) * out
+
+        if "bias" in self.linear:
+            out = out + self.linear.bias
+        return out
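For orientation (an editorial sketch, not part of the diff): DoRA keeps a learned per-output-row magnitude m and treats the LoRA-adapted weight as a direction, so the forward pass applies the low-rank update, normalizes by the row norms of the adapted weight, and rescales by m. A small MLX sketch of the same arithmetic with made-up shapes:

import mlx.core as mx

# Made-up shapes: 4 outputs, 3 inputs, rank 2 (illustrative only).
W = mx.random.normal((4, 3))        # frozen base weight
A = mx.random.normal((3, 2))        # lora_a
B = mx.zeros((2, 4))                # lora_b, starts at zero
m = mx.linalg.norm(W, axis=1)       # learned magnitude, initialized from W
scale = 20.0

x = mx.random.normal((3,))
y = x @ W.T + scale * ((x @ A) @ B)        # plain LoRA output
adapted = W + (scale * B.T) @ A.T          # adapted weight
denom = mx.linalg.norm(adapted, axis=1)    # row norms of the adapted weight
out = (m / denom) * y                      # rescale direction by learned magnitude
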
Some files were not shown because too many files have changed in this diff.