mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-10 14:09:02 +08:00)
remove mlx lm (#1353)
@@ -1,47 +0,0 @@
# Contributing to MLX LM

Below are some tips to port LLMs available on Hugging Face to MLX.

Before starting, check out the [general contribution
guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md).

Next, from this directory, do an editable install:

```shell
pip install -e .
```

Then check if the model has weights in the
[safetensors](https://huggingface.co/docs/safetensors/index) format. If not,
[follow instructions](https://huggingface.co/spaces/safetensors/convert) to
convert it.

After that, add the model file to the
[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models)
directory. You can see other examples there. We recommend starting from a model
that is similar to the model you are porting.

Make sure the name of the new model file is the same as the `model_type` in the
`config.json`, for example
[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17).

To determine the model layer names, we suggest either:

- Refer to the Transformers implementation if you are familiar with the
  codebase.
- Load the model weights and check the weight names which will tell you about
  the model structure.
- Look at the names of the weights by inspecting `model.safetensors.index.json`
  in the Hugging Face repo.
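
To make the structure concrete, here is a minimal sketch of what a new model
module typically contains. The file name, class names, and fields below are
illustrative placeholders, not the contents of any existing file; adapt them to
the architecture you are porting:

```python
# Hypothetical mlx_lm/models/my_model.py -- a minimal skeleton, not a real file.
from dataclasses import dataclass

import mlx.core as mx
import mlx.nn as nn


@dataclass
class ModelArgs:
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    vocab_size: int


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        # Replace this with the transformer blocks of the architecture you port;
        # the attribute names should match the weight names in the checkpoint.
        self.layers = [
            nn.Linear(args.hidden_size, args.hidden_size)
            for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None) -> mx.array:
        h = self.embed_tokens(inputs)
        for layer in self.layers:
            h = layer(h)
        return self.lm_head(self.norm(h))
```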

To add LoRA support edit
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60).

Finally, add a test for the new model type to the [model
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
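
A model test usually just builds the model from a small config and checks the
output shape. Here is a minimal sketch, assuming the hypothetical module from
above (the existing test file has its own helpers, so follow its pattern when
adding a real test):

```python
import unittest

import mlx.core as mx

from mlx_lm.models import my_model  # placeholder: the module you just added


class TestMyModel(unittest.TestCase):
    def test_forward_shape(self):
        args = my_model.ModelArgs(
            model_type="my_model",
            hidden_size=32,
            num_hidden_layers=2,
            vocab_size=100,
        )
        model = my_model.Model(args)
        inputs = mx.zeros((1, 8), dtype=mx.int32)
        logits = model(inputs)
        self.assertEqual(logits.shape, (1, 8, args.vocab_size))
```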

From the `llms/` directory, you can run the tests with:

```shell
python -m unittest discover tests/
```
@@ -1,2 +0,0 @@
include mlx_lm/requirements.txt
recursive-include mlx_lm/ *.py

llms/README.md
@@ -1,300 +1,6 @@
# DEPRECATED
# MOVE NOTICE

The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).

The package here will be removed soon. Send new contributions and issues to the MLX LM repo.

## Generate Text with LLMs and MLX

The easiest way to get started is to install the `mlx-lm` package:

**With `pip`**:

```sh
pip install mlx-lm
```

**With `conda`**:

```sh
conda install -c conda-forge mlx-lm
```

The `mlx-lm` package also has:

- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

### Quick Start

To generate text with an LLM use:

```bash
mlx_lm.generate --prompt "Hi!"
```

To chat with an LLM use:

```bash
mlx_lm.chat
```

This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.

Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:

```bash
mlx_lm.generate -h
```

### Python API

You can use `mlx-lm` as a module:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = "Write a story about Einstein"

messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True
)

text = generate(model, tokenizer, prompt=prompt, verbose=True)
```

To see a description of all the arguments you can do:

```
>>> help(generate)
```

Check out the [generation
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
to see how to use the API in more detail.

The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.

You can convert models using the Python API:

```python
from mlx_lm import convert

repo = "mistralai/Mistral-7B-Instruct-v0.3"
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"

convert(repo, quantize=True, upload_repo=upload_repo)
```

This will generate a 4-bit quantized Mistral 7B and upload it to the repo
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
converted model in the path `mlx_model` by default.

To see a description of all the arguments you can do:

```
>>> help(convert)
```

#### Streaming

For streaming generation, use the `stream_generate` function. This yields
a generation response object.

For example,

```python
from mlx_lm import load, stream_generate

repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
model, tokenizer = load(repo)

prompt = "Write a story about Einstein"

messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True
)

for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
    print(response.text, end="", flush=True)
print()
```

#### Sampling

The `generate` and `stream_generate` functions accept `sampler` and
`logits_processors` keyword arguments. A sampler is any callable which accepts
a possibly batched logits array and returns an array of sampled tokens. The
`logits_processors` must be a list of callables which take the token history
and current logits as input and return the processed logits. The logits
processors are applied in order.

Some standard sampling functions and logits processors are provided in
`mlx_lm.sample_utils`.
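
For example, a sketch of passing a sampler and a (here trivial) logits
processor to `generate`, using `make_sampler` from `mlx_lm.sample_utils`:

```python
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Temperature and top-p sampler built from the provided utilities.
sampler = make_sampler(0.6, 0.9)


# A logits processor takes the token history and the current logits and
# returns the processed logits. This one is a no-op placeholder.
def identity_processor(tokens, logits):
    return logits


text = generate(
    model,
    tokenizer,
    prompt="Write a story about Einstein",
    sampler=sampler,
    logits_processors=[identity_processor],
    verbose=True,
)
```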

### Command Line

You can also use `mlx-lm` from the command line with:

```
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
```

This will download a Mistral 7B model from the Hugging Face Hub and generate
text using the given prompt.

For a full list of options run:

```
mlx_lm.generate --help
```

To quantize a model from the command line run:

```
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
```

For more options run:

```
mlx_lm.convert --help
```

You can upload new models to Hugging Face by specifying `--upload-repo` to
`convert`. For example, to upload a quantized Mistral-7B model to the
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:

```
mlx_lm.convert \
    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
    -q \
    --upload-repo mlx-community/my-4bit-mistral
```

Models can also be converted and quantized directly in the
[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
Face Space.

### Long Prompts and Generations

`mlx-lm` has some tools to scale efficiently to long prompts and generations:

- A rotating fixed-size key-value cache.
- Prompt caching.

To use the rotating key-value cache pass the argument `--max-kv-size n` where
`n` can be any integer. Smaller values like `512` will use very little RAM but
result in worse quality. Larger values like `4096` or higher will use more RAM
but have better quality.
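
For example, something like:

```bash
mlx_lm.generate \
    --model mlx-community/Mistral-7B-Instruct-v0.3-4bit \
    --max-kv-size 1024 \
    --prompt "Write a story about Einstein"
```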

Caching prompts can substantially speed up reusing the same long context with
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:

```bash
cat prompt.txt | mlx_lm.cache_prompt \
    --model mistralai/Mistral-7B-Instruct-v0.3 \
    --prompt - \
    --prompt-cache-file mistral_prompt.safetensors
```

Then use the cached prompt with `mlx_lm.generate`:

```
mlx_lm.generate \
    --prompt-cache-file mistral_prompt.safetensors \
    --prompt "\nSummarize the above text."
```

The cached prompt is treated as a prefix to the supplied prompt. Also notice
that when using a cached prompt, the model to use is read from the cache and
need not be supplied explicitly.

Prompt caching can also be used in the Python API to avoid recomputing the
prompt. This is useful in multi-turn dialogues or across requests that use the
same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
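
A minimal sketch of a multi-turn loop that reuses one cache across turns,
mirroring the linked chat example:

```python
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# The cache is filled during the first turn and reused on later turns,
# so earlier context is not recomputed.
prompt_cache = make_prompt_cache(model)

for turn in ["Hello!", "Summarize our conversation so far."]:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": turn}], add_generation_prompt=True
    )
    for response in stream_generate(
        model, tokenizer, prompt, max_tokens=256, prompt_cache=prompt_cache
    ):
        print(response.text, end="", flush=True)
    print()
```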

### Supported Models

`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.

Here are a few examples of Hugging Face models that work with this example:

- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)

Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
and
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
style models should work out of the box.

For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
enable the `trust_remote_code` option. You can do this by passing
`--trust-remote-code` in the command line. If you don't specify the flag
explicitly, you will be prompted to trust remote code in the terminal when
running the model.

For `Qwen` models you must also specify the `eos_token`. You can do this by
passing `--eos-token "<|endoftext|>"` in the command line.

These options can also be set in the Python API. For example:

```python
model, tokenizer = load(
    "qwen/Qwen-7B",
    tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
```

### Large Models

> [!NOTE]
> This requires macOS 15.0 or higher to work.

Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
occupied by the model and cache. This requires macOS 15 or higher to work.

If you see the following warning message:

> [WARNING] Generating with a model that requires ...

then the model will likely be slow on the given machine. If the model fits in
RAM then it can often be sped up by increasing the system wired memory limit.
To increase the limit, set the following `sysctl`:

```bash
sudo sysctl iogpu.wired_limit_mb=N
```

The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.

The package has been removed from the MLX Examples repo. Send new contributions
and issues to the MLX LM repo.
@@ -1,392 +0,0 @@
# Fine-Tuning with LoRA or QLoRA

You can use the `mlx-lm` package to fine-tune an LLM with low rank
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

- Mistral
- Llama
- Phi2
- Mixtral
- Qwen2
- Gemma
- OLMo
- MiniCPM
- InternLM2

## Contents

- [Run](#Run)
  - [Fine-tune](#Fine-tune)
  - [Evaluate](#Evaluate)
  - [Generate](#Generate)
- [Fuse](#Fuse)
- [Data](#Data)
- [Memory Issues](#Memory-Issues)

## Run

The main command is `mlx_lm.lora`. To see a full list of command-line options run:

```shell
mlx_lm.lora --help
```

Note, in the following the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.

You can also specify a YAML config with `-c`/`--config`. For more on the format see the
[example YAML](examples/lora_config.yaml). For example:

```shell
mlx_lm.lora --config /path/to/config.yaml
```

If command-line flags are also used, they will override the corresponding
values in the config.

### Fine-tune

To fine-tune a model use:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --iters 600
```

To fine-tune the full model weights, add the `--fine-tune-type full` flag.
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.

The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).

For example, to fine-tune a Mistral 7B you can use `--model
mistralai/Mistral-7B-v0.1`.

If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
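
For example, one way to set this up (using the documented default output path
of `mlx_lm.convert`):

```shell
# Quantize the base model; the converted model is saved in mlx_model/ by default.
mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q

# Point --model at the quantized output to fine-tune with QLoRA.
mlx_lm.lora --model mlx_model --train --data <path_to_data>
```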

By default, the adapter config and learned weights are saved in `adapters/`.
You can specify the output location with `--adapter-path`.

You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

#### Prompt Masking

The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.
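
For example:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --mask-prompt
```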

### Evaluate

To compute test set perplexity use:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --adapter-path <path_to_adapters> \
    --data <path_to_data> \
    --test
```

### Generate

For generation use `mlx_lm.generate`:

```shell
mlx_lm.generate \
    --model <path_to_model> \
    --adapter-path <path_to_adapters> \
    --prompt "<your_model_prompt>"
```

## Fuse

You can generate a model fused with the low-rank adapters using the
`mlx_lm.fuse` command. This command also allows you to optionally:

- Upload the fused model to the Hugging Face Hub.
- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
  Mixtral, and Llama style models in fp16 precision.

To see supported options run:

```shell
mlx_lm.fuse --help
```

To generate the fused model run:

```shell
mlx_lm.fuse --model <path_to_model>
```

This will by default load the adapters from `adapters/`, and save the fused
model in the path `fused_model/`. All of these are configurable.

To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.

For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:

```shell
mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --upload-repo mlx-community/my-lora-mistral-7b \
    --hf-path mistralai/Mistral-7B-v0.1
```

To export a fused model to GGUF, run:

```shell
mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --export-gguf
```

This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.

## Data

The LoRA command expects you to provide a dataset with `--data`. The MLX
Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.

Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
Face.

### Local Datasets

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.

Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
data formats. Here are examples of these formats:

`chat`:

```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assist you today?"}]}
```

`tools`:

```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```

<details>
<summary>View the expanded single data tool format</summary>

```jsonl
{
  "messages": [
    { "role": "user", "content": "What is the weather in San Francisco?" },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "call_id",
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
          }
        }
      ]
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and country, eg. San Francisco, USA"
            },
            "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
          },
          "required": ["location", "format"]
        }
      }
    }
  ]
}
```

The format for the `arguments` field in a function varies for different models.
Common formats include JSON strings and dictionaries. The example provided
follows the format used by
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
and [Mistral
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
A dictionary format is used in Hugging Face's [chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
Refer to the documentation for the model you are fine-tuning for more details.

</details>

`completions`:

```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```

For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:

```yaml
prompt_feature: "input"
completion_feature: "output"
```

Here, `"input"` is the expected key instead of the default `"prompt"`, and
`"output"` is the expected key instead of `"completion"`.
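
With that config, each line of the dataset would instead look like:

```jsonl
{"input": "What is the capital of France?", "output": "Paris."}
```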

`text`:

```jsonl
{"text": "This is an example for the model."}
```

Note, the format is automatically determined by the dataset. Note also, keys
in each line not expected by the loader will be ignored.

> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example across multiple lines.

### Hugging Face Datasets

To use Hugging Face datasets, first install the `datasets` package:

```
pip install datasets
```

If the Hugging Face dataset is already in a supported format, you can specify
it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiSQL data.

Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:

```yaml
hf_dataset:
  name: "billsum"
  prompt_feature: "text"
  completion_feature: "summary"
```

- Use `prompt_feature` and `completion_feature` to specify keys for a
  `completions` dataset. Use `text_feature` to specify the key for a `text`
  dataset. Use `chat_feature` to specify the key for a chat dataset.

- To specify the train, valid, or test splits, set the corresponding
  `{train,valid,test}_split` argument.

You can specify a list of Hugging Face datasets with a list of records each
with the same structure as above. For example:

```yaml
hf_dataset:
  - name: "Open-Orca/OpenOrca"
    train_split: "train[:90%]"
    valid_split: "train[-10%:]"
    prompt_feature: "question"
    completion_feature: "response"
  - name: "trl-lib/ultrafeedback_binarized"
    train_split: "train[:90%]"
    valid_split: "train[-10%:]"
    chat_feature: "chosen"
```

- Arguments specified in `config` will be passed as keyword arguments to
  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

In general, for the `chat`, `tools` and `completions` formats, Hugging Face
[chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
are used. This applies the model's chat template by default. If the model does
not have a chat template, then Hugging Face will use a default. For example,
the final text in the `chat` example above with Hugging Face's default template
becomes:

```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assist you today?<|im_end|>
```

If you are unsure of the format to use, the `chat` or `completions` formats are
good places to start. For custom requirements on the format of the dataset, use
the `text` format to assemble the content yourself.

## Memory Issues

Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:

1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
   with `mlx_lm.convert` and the `-q` flag. See the [Fine-tune](#Fine-tune)
   section for more details.

2. Try using a smaller batch size with `--batch-size`. The default is `4` so
   setting this to `2` or `1` will reduce memory consumption. This may slow
   things down a little, but will also reduce the memory use.

3. Reduce the number of layers to fine-tune with `--num-layers`. The default
   is `16`, so you can try `8` or `4`. This reduces the amount of memory
   needed for back propagation. It may also reduce the quality of the
   fine-tuned model if you are fine-tuning with a lot of data.

4. Longer examples require more memory. If it makes sense for your data, one
   thing you can do is break your examples into smaller sequences when making
   the `{train, valid, test}.jsonl` files.

5. Gradient checkpointing lets you trade off memory use (less) for computation
   (more) by recomputing instead of storing intermediate values needed by the
   backward pass. You can use gradient checkpointing by passing the
   `--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
   larger batch sizes or sequence lengths with smaller or quantized models.

For example, for a machine with 32 GB the following should run reasonably fast:

```
mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --batch-size 1 \
    --num-layers 4 \
    --data wikisql
```

The above command on an M1 Max with 32 GB runs at about 250
tokens-per-second, using the MLX Example
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
data set.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.

[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
@@ -1,22 +0,0 @@
# Managing Models

You can use `mlx-lm` to manage models downloaded locally on your machine. They
are stored in the Hugging Face cache.

Scan models:

```shell
mlx_lm.manage --scan
```

Specify a `--pattern` to get info on a single model or a specific set of models:

```shell
mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
```

To delete a model (or multiple models):

```shell
mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
```
@@ -1,50 +0,0 @@
# Model Merging

You can use `mlx-lm` to merge models and upload them to the Hugging
Face Hub or save them locally for LoRA fine tuning.

The main command is `mlx_lm.merge`:

```shell
mlx_lm.merge --config config.yaml
```

The merged model will be saved by default in `mlx_merged_model`. To see a
full list of options run:

```shell
mlx_lm.merge --help
```

Here is an example `config.yaml`:

```yaml
models:
  - OpenPipe/mistral-ft-optimized-1218
  - mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
```

The `models` field is a list of Hugging Face repo ids. The first model in the
list is treated as the base model into which the remaining models are merged.

The `method` field is the merging method. Right now `slerp` is the only
supported method.

The `parameters` are the corresponding parameters for the given `method`.
Each parameter is a list with `filter` determining which layer the parameter
applies to and `value` determining the actual value used. The last item in
the list without a `filter` field is the default.

If `value` is a list, it specifies the start and end values for the
corresponding segment of blocks. In the example above, the models have 32
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
will use `np.linspace(0.5, 0.3, 8)`, and so on.
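
As a small worked example of that schedule, here are the per-block
interpolation weights implied by the `self_attn` filter above (assuming 32
blocks split into four segments of 8, as described):

```python
import numpy as np

values = [0, 0.5, 0.3, 0.7, 1]  # the self_attn `value` list from the config above
t = np.concatenate(
    [np.linspace(a, b, 8) for a, b in zip(values[:-1], values[1:])]
)
print(t.shape)  # (32,) -- one interpolation weight per block
```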
@@ -1,16 +0,0 @@
# DEPRECATED

The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).

The package here will be removed soon. Send new contributions and issues to the MLX LM repo.

## Generate Text with MLX and :hugs: Hugging Face

This is an example of large language model text generation that can pull models from
the Hugging Face Hub.

For more information on this example, see the [README](../README.md) in the
parent directory.

This package also supports fine tuning with LoRA or QLoRA. For more information
see the [LoRA documentation](LORA.md).
@@ -1,131 +0,0 @@
# HTTP Model Server

You can use `mlx-lm` to serve an HTTP API for generating text with any supported
model. The HTTP API is intended to be similar to the [OpenAI chat
API](https://platform.openai.com/docs/api-reference).

> [!NOTE]
> The MLX LM server is not recommended for production as it only implements
> basic security checks.

Start the server with:

```shell
mlx_lm.server --model <path_to_model_or_hf_repo>
```

For example:

```shell
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
```

This will start a text generation server on port `8080` of the `localhost`
using Mistral 7B instruct. The model will be downloaded from the provided
Hugging Face repo if it is not already in the local cache.

To see a full list of options run:

```shell
mlx_lm.server --help
```

You can make a request to the model by running:

```shell
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "messages": [{"role": "user", "content": "Say this is a test!"}],
     "temperature": 0.7
   }'
```

### Request Fields

- `messages`: An array of message objects representing the conversation
  history. Each message object should have a role (e.g. user, assistant) and
  content (the message text).

- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
  the generated prompt. If not provided, the default mappings are used.

- `stop`: (Optional) An array of strings or a single string. These are
  sequences of tokens on which the generation should stop.

- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
  to generate. Defaults to `100`.

- `stream`: (Optional) A boolean indicating if the response should be
  streamed. If true, responses are sent as they are generated. Defaults to
  false.

- `temperature`: (Optional) A float specifying the sampling temperature.
  Defaults to `1.0`.

- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
  Defaults to `1.0`.

- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
  Defaults to `1.0`.

- `repetition_context_size`: (Optional) The size of the context window for
  applying the repetition penalty. Defaults to `20`.

- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
  values. Defaults to `None`.

- `logprobs`: (Optional) An integer specifying the number of top tokens and
  corresponding log probabilities to return for each output in the generated
  sequence. If set, this can be any value between 1 and 10, inclusive.

- `model`: (Optional) A string path to a local model or Hugging Face repo id.
  If the path is local, it must be relative to the directory the server was
  started in.

- `adapters`: (Optional) A string path to low-rank adapters. The path must be
  relative to the directory the server was started in.
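
As a sketch, the same request can be made from Python using the third-party
`requests` package (installed separately); the response shape follows the
fields documented below:

```python
import requests  # third-party HTTP client

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7,
        "max_tokens": 100,
    },
)
print(response.json()["choices"][0]["message"])
```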

### Response Fields

- `id`: A unique identifier for the chat.

- `system_fingerprint`: A unique identifier for the system.

- `object`: Any of "chat.completion", "chat.completion.chunk" (for
  streaming), or "text.completion".

- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).

- `created`: A time-stamp for when the request was processed.

- `choices`: A list of outputs. Each output is a dictionary containing the fields:
  - `index`: The index in the list.
  - `logprobs`: A dictionary containing the fields:
    - `token_logprobs`: A list of the log probabilities for the generated
      tokens.
    - `tokens`: A list of the generated token ids.
    - `top_logprobs`: A list of lists. Each list contains the `logprobs`
      top tokens (if requested) with their corresponding probabilities.
  - `finish_reason`: The reason the completion ended. This can be either of
    `"stop"` or `"length"`.
  - `message`: The text response from the model.

- `usage`: A dictionary containing the fields:
  - `prompt_tokens`: The number of prompt tokens processed.
  - `completion_tokens`: The number of tokens generated.
  - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.

### List Models

Use the `v1/models` endpoint to list available models:

```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```

This will return a list of locally available models where each model in the
list contains the following fields:

- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.
@@ -1,37 +0,0 @@
### Packaging for PyPI

Install `build` and `twine`:

```
pip install --user --upgrade build
pip install --user --upgrade twine
```

Generate the source distribution and wheel:

```
python -m build
```

> [!WARNING]
> Use a test server first.

#### Test Upload

Upload to the test server:

```
python -m twine upload --repository testpypi dist/*
```

Install from the test server and check that it works:

```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
```

#### Upload

```
python -m twine upload dist/*
```
@@ -1,9 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

import os

from ._version import __version__

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

from .utils import convert, generate, load, stream_generate
@@ -1,3 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.21.6"
@@ -1,161 +0,0 @@
# Copyright © 2024 Apple Inc.

import argparse
import json
import sys
import time

import mlx.core as mx

from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import generate_step, load

DEFAULT_QUANTIZED_KV_START = 5000


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(
        description="Cache the state of a prompt to be reused with mlx_lm.generate"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable trusting remote code for tokenizer",
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default=None,
        help="End of sequence token for tokenizer",
    )
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    parser.add_argument(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        default=None,
        help="Set the maximum key-value cache size",
    )
    parser.add_argument(
        "--prompt-cache-file",
        help="The file to save the prompt cache in",
        required=True,
    )
    parser.add_argument(
        "--prompt",
        required=True,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        help="Number of bits for KV cache quantization. "
        "Defaults to no quantization.",
        default=None,
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        help="Group size for KV cache quantization.",
        default=64,
    )
    parser.add_argument(
        "--quantized-kv-start",
        help="When --kv-bits is set, start quantizing the KV cache "
        "from this step onwards.",
        type=int,
        default=DEFAULT_QUANTIZED_KV_START,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    # Building tokenizer_config
    tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
    if args.eos_token is not None:
        tokenizer_config["eos_token"] = args.eos_token

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )

    args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template

    if not args.ignore_chat_template and tokenizer.chat_template is not None:
        messages = [{"role": "user", "content": args.prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, add_generation_prompt=False, continue_final_message=True
        )
    else:
        prompt = tokenizer.encode(args.prompt)

    cache = make_prompt_cache(model, args.max_kv_size)
    y = mx.array(prompt)

    # Process the prompt
    start = time.time()
    max_msg_len = 0

    def callback(processed, total_tokens):
        current = time.time()
        speed = processed / (current - start)
        msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
        nonlocal max_msg_len
        max_msg_len = max(max_msg_len, len(msg))
        print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)

    for _ in generate_step(
        y,
        model,
        max_tokens=0,
        prompt_cache=cache,
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
        prompt_progress_callback=callback,
    ):
        pass

    print()
    print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")

    print("Saving...")
    metadata = {}
    metadata["model"] = args.model
    metadata["chat_template"] = json.dumps(tokenizer.chat_template)
    metadata["tokenizer_config"] = json.dumps(tokenizer_config)
    save_prompt_cache(args.prompt_cache_file, cache, metadata)


if __name__ == "__main__":
    main()
@@ -1,108 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import json

import mlx.core as mx

from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = None
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help="PRNG seed",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.seed is not None:
        mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    def print_help():
        print("The command list:")
        print("- 'q' to exit")
        print("- 'r' to reset the chat")
        print("- 'h' to display these commands")

    print(f"[INFO] Starting chat session with {args.model}.")
    print_help()
    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        query = input(">> ")
        if query == "q":
            break
        if query == "r":
            prompt_cache = make_prompt_cache(model, args.max_kv_size)
            continue
        if query == "h":
            print_help()
            continue
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            max_tokens=args.max_tokens,
            sampler=make_sampler(args.temp, args.top_p),
            prompt_cache=prompt_cache,
        ):
            print(response.text, flush=True, end="")
        print()


if __name__ == "__main__":
    main()
@@ -1,83 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

from . import utils
from .utils import convert

QUANT_RECIPES = [
    "mixed_2_6",
    "mixed_3_6",
]


def quant_args(arg):
    if arg not in QUANT_RECIPES:
        raise argparse.ArgumentTypeError(
            f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}"
        )
    else:
        return getattr(utils, arg)


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )

    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--quant-predicate",
        help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}",
        type=quant_args,
        required=False,
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the non-quantized parameters.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--dequantize",
        help="Dequantize a quantized model.",
        action="store_true",
        default=False,
    )
    return parser


def main():
    parser = configure_parser()
    args = parser.parse_args()
    convert(**vars(args))


if __name__ == "__main__":
    main()
@@ -1,392 +0,0 @@
# Copyright © 2024 Apple Inc.

"""
Adapted from a PyTorch implementation by David Grangier
"""

import argparse
import json
import logging
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional, Union

import lm_eval
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm

from .models.cache import make_prompt_cache
from .utils import load, stream_generate

PAD = 0


def _len_longest_common_prefix(a, b):
    l = 0
    for item_a, item_b in zip(a, b):
        if item_a != item_b:
            break
        l += 1
    return l


def _rstrip_until(s, untils):
    """Limit a string <s> to the first occurrence of any substring in untils."""
    l = len(s)
    f = [s.find(u) for u in untils]
    f = [l if x < 0 else x for x in f]
    return s[: min(f)]


def _pad_inputs(
    inputs,
    maxlen,
    genlen=0,
    pad_left=False,
    pad_multiple=32,
    truncate=False,
):
    # pad the prompts to the left with at least genlen tokens.
    actual_maxlen = max(len(p) for p in inputs) + genlen
    if actual_maxlen > maxlen:
        if not truncate:
            raise ValueError("Inputs are too long.")
        else:  # drop beginning
            actual_maxlen = maxlen
            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
    if pad_multiple > 0:
        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
        maxlen *= pad_multiple
    assert PAD == 0
    lr = np.array((1, 0) if pad_left else (0, 1))
    return np.stack(
        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
        axis=0,
    )


@register_model("mlxlm")
class MLXLM(LM):
    def __init__(
        self,
        path_or_hf_repo: str,
        batch_size: int = 16,
        max_tokens: Optional[int] = None,
        use_chat_template: Optional[bool] = None,
    ) -> None:
        super().__init__()
        self._batch_size = batch_size
        self._model, self.tokenizer = load(path_or_hf_repo)
        self._max_tokens = max_tokens or self.tokenizer.model_max_length
        self.use_chat_template = use_chat_template or (
            self.tokenizer.chat_template is not None
        )

    def _score_fn(self, inputs, tokenize=True, step_size=32):
        if tokenize:
            inputs = self._tokenize(inputs)
        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
        inputs = mx.array(inputs)
        inputs, targets = inputs[..., :-1], inputs[..., 1:]

        cache = make_prompt_cache(self._model)

        mask = targets != PAD

        scores, is_greedy = [], []
        for i in range(0, inputs.shape[1], step_size):
            logits = self._model(inputs[:, i : i + step_size], cache=cache)

            log_probs = nn.log_softmax(logits.astype(mx.float32))
            score = mx.take_along_axis(
                log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
            )[..., 0]
            ig = mask[:, i : i + step_size] * (
                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
            )

            mx.eval(score, ig)
            mx.metal.clear_cache()

            is_greedy.append(ig)
            scores.append(score)

        scores = mx.concatenate(scores, axis=1)
        is_greedy = mx.concatenate(is_greedy, axis=1)

        return scores, mask.sum(axis=-1), is_greedy

    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
        # sort by length to get batches with little padding.
        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
        sorted_spans = None
        if score_spans is not None:
            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]

        results = []
        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
            batch = sorted_inputs[i : i + self._batch_size]
            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
            for j in range(len(batch)):
                if sorted_spans is None:  # full sequence score
                    mask = mx.arange(scores[j].shape[-1]) < length
                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
                else:  # subsequence score
                    start, end = sorted_spans[i + j]
                    score = scores[j][start:end].astype(mx.float32).sum()
                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
                    length = end - start

                results.append((score.item(), ig.item(), length))

        # reorder the outputs
        inv_sort = np.argsort(sorted_indices)
        results = [results[inv_sort[i]] for i in range(len(results))]

        return results

    def _tokenize(self, texts):
        return [
            tuple(
                self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template)
            )
            for t in texts
        ]

    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.
        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.
        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.
        """
        logging.info("Estimating loglikelihood for %d pairs." % len(requests))

        # tokenize prefix and prefix + completion for all requests.
        tokenized = self._tokenize(
            [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]]
        )

        # max length (prefix + completion) and longest common prefix per question.
        length_stats = {}
        for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
            max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
            length_stats[prefix] = (
                max(max_completed_l, len(completed)),
                min(min_prefix_l, _len_longest_common_prefix(prefix, completed)),
            )

        # truncate requests for completed sequences longer than model context.
        shortened = []
        completion_spans = []
        long_completions = 0
        for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
            max_completed_l, prefix_l = length_stats[prefix]
            # compute truncation length
            truncation = max(0, max_completed_l - self._max_tokens - 1)
            prefix_l = prefix_l - truncation
            if prefix_l <= 0:
                # completion too long, prefix is eliminated for some requests.
                long_completions += 1
                truncation = max(0, len(completed) - self._max_tokens - 1)
                prefix_l = 1
            # truncate the completed sequence
            completed = completed[truncation:]
            shortened.append(completed)
            # scores do not include initial bos, subtract 1 from span bounds
            completion_spans.append((prefix_l - 1, len(completed) - 1))

        if long_completions > 0:
            logging.info(
                f"Prefix eliminated for {long_completions} requests with "
                + "completion longer than context."
            )

        # model scoring, returns num_requests x (logp, is_greedy, length).
        results = self._loglikelihood(
            shortened,
            score_spans=completion_spans,
            tokenize=False,
        )
        return [(r[0], r[1] == r[2]) for r in results]

    tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template

    def loglikelihood_rolling(self, requests) -> list[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:
                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3
                INPUT:    3   4   5   6
                PRED:     4   5   6   7
                INPUT:    5   6   7   8
                PRED:             8   9
          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens
        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            string: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the EOT token.
        """
        logging.info(
            "Estimating loglikelihood rolling for %d sequences." % len(requests)
        )
        inputs = [req.args[0] for req in requests]
        return [t[0] for t in self._loglikelihood(inputs)]

    def generate_until(self, requests) -> list[str]:
        """Generate greedily until a stopping sequence
        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, until).
            context: str
                Context string
            until: [str]
                The string sequences to generate until. These string sequences
                may each span across multiple tokens, or may be part of one token.
        :return: list[str]
            A list of strings continuation
            continuation: str
                The generated continuation.
        """
        logging.info("Generating continuation for %d sequences." % len(requests))
        contexts, options = zip(*[req.args for req in requests])
        # contrary to the doc the second element of the tuple contains
        # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
        completions = []

        for context, opt in tqdm(zip(contexts, options), total=len(contexts)):
            until = opt["until"]
            context = self.tokenizer.encode(
                context, add_special_tokens=not self.use_chat_template
            )
            max_tokens = min(
                opt.get("max_gen_tokens", self._max_tokens),
                self.tokenizer.model_max_length - len(context),
            )
            text = ""
            for response in stream_generate(
                self._model, self.tokenizer, prompt=context, max_tokens=max_tokens
            ):
                text += response.text
                if any(u in text for u in until):
                    text = _rstrip_until(text, until)
                    completions.append(text)
                    break
            else:
                completions.append(text)
        return completions


def main():
    parser = argparse.ArgumentParser(
        "Evaluate an MLX model using lm-evaluation-harness."
    )
||||
parser.add_argument("--model", help="Model to evaluate", required=True)
|
||||
parser.add_argument("--tasks", nargs="+", required=True)
|
||||
parser.add_argument(
|
||||
"--output-dir", default=".", help="Output directory for result files."
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
|
||||
parser.add_argument("--num-shots", type=int, default=0, help="Number of shots")
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
help="Maximum nunber of tokens to generate. Defaults to the model's max context length.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
default=100,
|
||||
help="Limit the number of examples per task.",
|
||||
type=int,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
|
||||
parser.add_argument(
|
||||
"--fewshot-as-multiturn",
|
||||
action="store_true",
|
||||
help="Whether to provide the fewshot examples as a multiturn "
|
||||
"conversation or a single user turn.",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply-chat-template",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Specifies whether to apply a chat template to the prompt. If "
|
||||
"the model has a chat template, this defaults to `True`, "
|
||||
"otherwise `False`.",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Silence tokenizer warnings
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
lm = MLXLM(
|
||||
args.model,
|
||||
batch_size=args.batch_size,
|
||||
max_tokens=args.max_tokens,
|
||||
use_chat_template=args.apply_chat_template,
|
||||
)
|
||||
results = lm_eval.simple_evaluate(
|
||||
model=lm,
|
||||
tasks=args.tasks,
|
||||
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
||||
apply_chat_template=lm.use_chat_template,
|
||||
num_fewshot=args.num_shots,
|
||||
limit=args.limit,
|
||||
random_seed=args.seed,
|
||||
numpy_random_seed=args.seed,
|
||||
torch_random_seed=args.seed,
|
||||
fewshot_random_seed=args.seed,
|
||||
)
|
||||
|
||||
model_name = args.model.replace("/", "_")
|
||||
task_names = "_".join(args.tasks)
|
||||
ver = version("lm_eval")
|
||||
filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json"
|
||||
output_path = output_dir / filename
|
||||
output_path.write_text(json.dumps(results["results"], indent=4))
|
||||
print("Results:")
|
||||
for result in results["results"].values():
|
||||
print(json.dumps(result, indent=4))
|
||||
@@ -1,47 +0,0 @@
# Copyright © 2024 Apple Inc.

"""
An example of a multi-turn chat with prompt caching.
"""

from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)

# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Assistant response
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    prompt_cache=prompt_cache,
)

# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Assistant response
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    prompt_cache=prompt_cache,
)

# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)

# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
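# The reloaded cache can then be passed back to generate(..., prompt_cache=prompt_cache)
# to continue the same conversation without re-processing the earlier turns.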
@@ -1,33 +0,0 @@
# Copyright © 2024 Apple Inc.

from mlx_lm import generate, load

# Specify the checkpoint
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"

# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)

# Specify the prompt and conversation history
prompt = "Why is the sky blue?"
conversation = [{"role": "user", "content": prompt}]

# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
    conversation=conversation, add_generation_prompt=True
)

# Specify the maximum number of tokens
max_tokens = 1_000

# Specify if tokens and timing information will be printed
verbose = True

# Generate a response with the specified settings
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    verbose=verbose,
)
@@ -1,89 +0,0 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"

# Whether or not to train (boolean)
train: true

# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora

# The Optimizer with its possible inputs
optimizer: adamw
# optimizer_config:
#   adamw:
#     betas: [0.9, 0.98]
#     eps: 1e-6
#     weight_decay: 0.05
#     bias_correction: true

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

# The PRNG seed
seed: 0

# Number of layers to fine-tune
num_layers: 16

# Minibatch size.
batch_size: 4

# Iterations to train for.
iters: 1000

# Number of validation batches, -1 uses the entire validation set.
val_batches: 25

# Adam learning rate.
learning_rate: 1e-5

# Number of training steps between loss reporting.
steps_per_report: 10

# Number of training steps between validations.
steps_per_eval: 200

# Load path to resume training with the given adapter weights.
resume_adapter_file: null

# Save/load path for the trained adapter weights.
adapter_path: "adapters"

# Save the model every N iterations.
save_every: 100

# Evaluate on the test set after training
test: false

# Number of test set batches, -1 uses the entire test set.
test_batches: 100

# Maximum sequence length.
max_seq_length: 2048

# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false

# LoRA parameters can only be specified in a config file
lora_parameters:
  # The layer keys to apply LoRA to.
  # These will be applied for the last lora_layers
  keys: ["self_attn.q_proj", "self_attn.v_proj"]
  rank: 8
  scale: 20.0
  dropout: 0.0

# Schedule can only be specified in a config file, uncomment to use.
#lr_schedule:
#  name: cosine_decay
#  warmup: 100 # 0 for no warmup
#  warmup_init: 1e-7 # 0 if not specified
#  arguments: [1e-5, 1000, 1e-7] # passed to scheduler

#hf_dataset:
#  name: "billsum"
#  train_split: "train[:1000]"
#  valid_split: "train[-100:]"
#  prompt_feature: "text"
#  completion_feature: "summary"
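# Usage sketch (assuming the mlx_lm.lora entry point from this package; the path is a placeholder):
#   mlx_lm.lora --config /path/to/lora_config.yaml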
@@ -1,11 +0,0 @@
models:
  - OpenPipe/mistral-ft-optimized-1218
  - mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
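# Usage sketch (assuming the mlx_lm.merge entry point from this package; paths are placeholders):
#   mlx_lm.merge --config /path/to/merge_config.yaml --mlx-path mlx_merged_model
# Each per-filter t list above is interpolated across the model's layers before being fed to slerp;
# the final unfiltered value is the default for all remaining parameters.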
@@ -1,127 +0,0 @@
# Copyright © 2024 Apple Inc.

"""
Run with:

```
mlx.launch \
 --hostfile /path/to/hosts.txt \
 --backend mpi \
 /path/to/pipeline_generate.py \
 --prompt "hello world"
```

Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:

https://ml-explore.github.io/mlx/build/html/usage/distributed.html
"""

import argparse
import json
from pathlib import Path

import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx_lm import load, stream_generate
from mlx_lm.utils import load_model, load_tokenizer


def download(repo: str, allow_patterns: list[str]) -> Path:
    return Path(
        snapshot_download(
            repo,
            allow_patterns=allow_patterns,
        )
    )


def shard_and_load(repo):
    # Get model path with everything but weight safetensors
    model_path = download(
        args.model,
        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
    )

    # Lazy load and shard model to figure out
    # which weights we need
    model, _ = load_model(model_path, lazy=True, strict=False)

    group = mx.distributed.init(backend="mpi")
    rank = group.rank()
    model.model.pipeline(group)

    # Figure out which files we need for the local shard
    with open(model_path / "model.safetensors.index.json", "r") as fid:
        weight_index = json.load(fid)["weight_map"]

    local_files = set()
    for k, _ in tree_flatten(model.parameters()):
        local_files.add(weight_index[k])

    # Download weights for local shard
    download(args.model, allow_patterns=local_files)

    # Load and shard the model, and load the weights
    tokenizer = load_tokenizer(model_path)
    model, _ = load_model(model_path, lazy=True, strict=False)
    model.model.pipeline(group)
    mx.eval(model.parameters())

    # Synchronize processes before generation to avoid timeout if downloading
    # model for the first time.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLM pipelined inference example")
    parser.add_argument(
        "--model",
        default="mlx-community/DeepSeek-R1-3bit",
        help="HF repo or path to local model.",
    )
    parser.add_argument(
        "--prompt",
        "-p",
        default="Write a quicksort in C++.",
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=256,
        help="Maximum number of tokens to generate",
    )
    args = parser.parse_args()

    group = mx.distributed.init(backend="mpi")
    rank = group.rank()

    def rprint(*args, **kwargs):
        if rank == 0:
            print(*args, **kwargs)

    model, tokenizer = shard_and_load(args.model)

    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    for response in stream_generate(
        model, tokenizer, prompt, max_tokens=args.max_tokens
    ):
        rprint(response.text, end="", flush=True)

    rprint()
    rprint("=" * 10)
    rprint(
        f"Prompt: {response.prompt_tokens} tokens, "
        f"{response.prompt_tps:.3f} tokens-per-sec"
    )
    rprint(
        f"Generation: {response.generation_tokens} tokens, "
        f"{response.generation_tps:.3f} tokens-per-sec"
    )
    rprint(f"Peak memory: {response.peak_memory:.3f} GB")
@@ -1,73 +0,0 @@
# Copyright © 2025 Apple Inc.

import json

from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache

# Specify the checkpoint
checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"

# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)


# An example tool, make sure to include a docstring and type hints
def multiply(a: float, b: float):
    """
    A function that multiplies two numbers

    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b


tools = {"multiply": multiply}

# Specify the prompt and conversation history
prompt = "Multiply 12234585 and 48838483920."
messages = [{"role": "user", "content": prompt}]

prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tools=list(tools.values())
)

prompt_cache = make_prompt_cache(model)

# Generate the initial tool call:
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=2048,
    verbose=True,
    prompt_cache=prompt_cache,
)

# Parse the tool call:
# (Note, the tool call format is model specific)
tool_open = "<tool_call>"
tool_close = "</tool_call>"
start_tool = response.find(tool_open) + len(tool_open)
end_tool = response.find(tool_close)
tool_call = json.loads(response[start_tool:end_tool].strip())
tool_result = tools[tool_call["name"]](**tool_call["arguments"])
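# For reference, with Qwen-style chat templates the generated block parsed above is
# typically of the form (exact formatting may vary by model):
#   <tool_call>{"name": "multiply", "arguments": {"a": 12234585, "b": 48838483920}}</tool_call>
# which is why json.loads on the inner text yields a dict with "name" and "arguments".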
# Put the tool result in the prompt
messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
)

# Generate the final response:
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=2048,
    verbose=True,
    prompt_cache=prompt_cache,
)
@@ -1,130 +0,0 @@
import argparse
import glob
import shutil
from pathlib import Path

from mlx.utils import tree_flatten, tree_unflatten

from .gguf import convert_to_gguf
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import dequantize, load_adapters
from .utils import (
    fetch_from_hub,
    get_model_path,
    save_config,
    save_weights,
    upload_to_hub,
)


def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Fuse fine-tuned adapters into the base model."
    )
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--save-path",
        default="fused_model",
        help="The path to save the fused model.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        default="adapters",
        help="Path to the trained adapter weights and config.",
    )
    parser.add_argument(
        "--hf-path",
        type=str,
        default=None,
        help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--de-quantize",
        help="Generate a de-quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--export-gguf",
        help="Export model weights in GGUF format.",
        action="store_true",
    )
    parser.add_argument(
        "--gguf-path",
        help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
        default="ggml-model-f16.gguf",
        type=str,
    )
    return parser.parse_args()


def main() -> None:
    print("Loading pretrained model")
    args = parse_arguments()

    model_path = get_model_path(args.model)
    model, config, tokenizer = fetch_from_hub(model_path)

    model.freeze()
    model = load_adapters(model, args.adapter_path)

    fused_linears = [
        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
    ]

    if fused_linears:
        model.update_modules(tree_unflatten(fused_linears))

    if args.de_quantize:
        print("De-quantizing model")
        model = dequantize(model)

    weights = dict(tree_flatten(model.parameters()))

    save_path = Path(args.save_path)

    save_weights(save_path, weights)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, save_path)

    tokenizer.save_pretrained(save_path)

    if args.de_quantize:
        config.pop("quantization", None)

    save_config(config, config_path=save_path / "config.json")

    if args.export_gguf:
        model_type = config["model_type"]
        if model_type not in ["llama", "mixtral", "mistral"]:
            raise ValueError(
                f"Model type {model_type} not supported for GGUF conversion."
            )
        convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))

    if args.upload_repo is not None:
        hf_path = args.hf_path or (
            args.model if not Path(args.model).exists() else None
        )
        if hf_path is None:
            raise ValueError(
                "Must provide original Hugging Face repo to upload local model."
            )
        upload_to_hub(args.save_path, args.upload_repo, hf_path)


if __name__ == "__main__":
    main()
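# Usage sketch (assuming the mlx_lm.fuse entry point; paths are placeholders):
#   mlx_lm.fuse --model mlx_model --adapter-path adapters --save-path fused_model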
@@ -1,287 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import json
import sys

import mlx.core as mx

from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load

DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = None
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000


def str2bool(string):
    return string.lower() not in ["false", "f"]


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="LLM inference script")
    parser.add_argument(
        "--model",
        type=str,
        help=(
            "The path to the local model directory or Hugging Face repo. "
            f"If no model is specified, then {DEFAULT_MODEL} is used."
        ),
        default=None,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--extra-eos-token",
        type=str,
        default=(),
        nargs="+",
        help="Add tokens in the list of eos tokens that stop generation.",
    )
    parser.add_argument(
        "--system-prompt",
        default=None,
        help="System prompt to be used for the chat template",
    )
    parser.add_argument(
        "--prompt",
        "-p",
        default=DEFAULT_PROMPT,
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--prefill-response",
        default=None,
        help="Prefill response to be used for the chat template",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument(
        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
    )
    parser.add_argument(
        "--min-tokens-to-keep",
        type=int,
        default=DEFAULT_MIN_TOKENS_TO_KEEP,
        help="Minimum tokens to keep for min-p sampling.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help="PRNG seed",
    )
    parser.add_argument(
        "--ignore-chat-template",
        action="store_true",
        help="Use the raw prompt without the tokenizer's chat template.",
    )
    parser.add_argument(
        "--use-default-chat-template",
        action="store_true",
        help="Use the default chat template",
    )
    parser.add_argument(
        "--chat-template-config",
        help="Additional config for `apply_chat_template`. Should be a dictionary of"
        " string keys to values represented as a JSON decodable string.",
        default=None,
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=True,
        help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    parser.add_argument(
        "--prompt-cache-file",
        type=str,
        default=None,
        help="A file containing saved KV caches to avoid recomputing them",
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        help="Number of bits for KV cache quantization. "
        "Defaults to no quantization.",
        default=None,
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        help="Group size for KV cache quantization.",
        default=64,
    )
    parser.add_argument(
        "--quantized-kv-start",
        help="When --kv-bits is set, start quantizing the KV cache "
        "from this step onwards.",
        type=int,
        default=DEFAULT_QUANTIZED_KV_START,
    )
    parser.add_argument(
        "--draft-model",
        type=str,
        help="A model to be used for speculative decoding.",
        default=None,
    )
    parser.add_argument(
        "--num-draft-tokens",
        type=int,
        help="Number of tokens to draft when using speculative decoding.",
        default=3,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.seed is not None:
        mx.random.seed(args.seed)

    # Load the prompt cache and metadata if a cache file is provided
    using_cache = args.prompt_cache_file is not None
    if using_cache:
        prompt_cache, metadata = load_prompt_cache(
            args.prompt_cache_file,
            return_metadata=True,
        )
        if isinstance(prompt_cache[0], QuantizedKVCache):
            if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
                raise ValueError(
                    "--kv-bits does not match the kv cache loaded from --prompt-cache-file."
                )
            if args.kv_group_size != prompt_cache[0].group_size:
                raise ValueError(
                    "--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
                )

    # Building tokenizer_config
    tokenizer_config = (
        {} if not using_cache else json.loads(metadata["tokenizer_config"])
    )
    tokenizer_config["trust_remote_code"] = True

    model_path = args.model
    if using_cache:
        if model_path is None:
            model_path = metadata["model"]
        elif model_path != metadata["model"]:
            raise ValueError(
                f"Providing a different model ({model_path}) than that "
                f"used to create the prompt cache ({metadata['model']}) "
                "is an error."
            )
    model_path = model_path or DEFAULT_MODEL

    model, tokenizer = load(
        model_path,
        adapter_path=args.adapter_path,
        tokenizer_config=tokenizer_config,
    )
    for eos_token in args.extra_eos_token:
        tokenizer.add_eos_token(eos_token)

    template_kwargs = {}
    if args.chat_template_config is not None:
        template_kwargs = json.loads(args.chat_template_config)

    if args.use_default_chat_template:
        if tokenizer.chat_template is None:
            tokenizer.chat_template = tokenizer.default_chat_template
    elif using_cache:
        tokenizer.chat_template = json.loads(metadata["chat_template"])

    prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
    prompt = sys.stdin.read() if prompt == "-" else prompt
    if not args.ignore_chat_template and tokenizer.chat_template is not None:
        if args.system_prompt is not None:
            messages = [{"role": "system", "content": args.system_prompt}]
        else:
            messages = []
        messages.append({"role": "user", "content": prompt})

        has_prefill = args.prefill_response is not None
        if has_prefill:
            messages.append({"role": "assistant", "content": args.prefill_response})
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            continue_final_message=has_prefill,
            add_generation_prompt=not has_prefill,
            **template_kwargs,
        )

        # Treat the prompt as a suffix assuming that the prefix is in the
        # stored kv cache.
        if using_cache:
            messages[-1]["content"] = "<query>"
            test_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                continue_final_message=has_prefill,
                add_generation_prompt=not has_prefill,
            )
            prompt = prompt[test_prompt.index("<query>") :]
        prompt = tokenizer.encode(prompt, add_special_tokens=False)
    else:
        prompt = tokenizer.encode(prompt)

    if args.draft_model is not None:
        draft_model, draft_tokenizer = load(args.draft_model)
        if draft_tokenizer.vocab_size != tokenizer.vocab_size:
            raise ValueError("Draft model tokenizer does not match model tokenizer.")
    else:
        draft_model = None
    sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
    response = generate(
        model,
        tokenizer,
        prompt,
        max_tokens=args.max_tokens,
        verbose=args.verbose,
        sampler=sampler,
        max_kv_size=args.max_kv_size,
        prompt_cache=prompt_cache if using_cache else None,
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
        draft_model=draft_model,
        num_draft_tokens=args.num_draft_tokens,
    )
    if not args.verbose:
        print(response)


if __name__ == "__main__":
    main()
@@ -1,314 +0,0 @@
import re
from enum import IntEnum
from pathlib import Path
from typing import Iterable, Optional, Set, Tuple, Union

import mlx.core as mx
from transformers import AutoTokenizer


class TokenType(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


class GGMLFileType(IntEnum):
    GGML_TYPE_F16 = 1


# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
class HfVocab:
    def __init__(
        self,
        fname_tokenizer: Path,
        fname_added_tokens: Optional[Union[Path, None]] = None,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            fname_tokenizer,
            cache_dir=fname_tokenizer,
            local_files_only=True,
        )
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids = set()
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)
        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer
        self.fname_added_tokens = fname_added_tokens

    def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }
        for token_id in range(self.vocab_size_base):
            if token_id in self.added_tokens_ids:
                continue
            token_text = reverse_vocab[token_id]
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids
            )

    def get_token_type(
        self, token_id: int, token_text: bytes, special_ids: Set[int]
    ) -> TokenType:
        if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
            return TokenType.BYTE
        return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        return -1000.0

    def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], "", self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = TokenType.USER_DEFINED
                score = -1000.0
            yield text, score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"

    @staticmethod
    def load(path: Path) -> "HfVocab":
        added_tokens_path = path.parent / "added_tokens.json"
        return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)


def translate_weight_names(name):
    name = name.replace("model.layers.", "blk.")
    # for mixtral gate
    name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
    # for mixtral experts ffns
    pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
    replacement = r"ffn_gate.\1.weight"
    name = re.sub(pattern, replacement, name)
    pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
    replacement = r"ffn_down.\1.weight"
    name = re.sub(pattern, replacement, name)
    pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
    replacement = r"ffn_up.\1.weight"
    name = re.sub(pattern, replacement, name)

    name = name.replace("mlp.gate_proj", "ffn_gate")
    name = name.replace("mlp.down_proj", "ffn_down")
    name = name.replace("mlp.up_proj", "ffn_up")
    name = name.replace("self_attn.q_proj", "attn_q")
    name = name.replace("self_attn.k_proj", "attn_k")
    name = name.replace("self_attn.v_proj", "attn_v")
    name = name.replace("self_attn.o_proj", "attn_output")
    name = name.replace("input_layernorm", "attn_norm")
    name = name.replace("post_attention_layernorm", "ffn_norm")
    name = name.replace("model.embed_tokens", "token_embd")
    name = name.replace("model.norm", "output_norm")
    name = name.replace("lm_head", "output")
    return name


def permute_weights(weights, n_head, n_head_kv=None):
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    reshaped = weights.reshape(
        n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
    )
    swapped = reshaped.swapaxes(1, 2)
    final_shape = weights.shape
    return swapped.reshape(final_shape)


def prepare_metadata(config, vocab):
    metadata = {
        "general.name": "llama",
        "llama.context_length": (
            mx.array(config["max_position_embeddings"], dtype=mx.uint32)
            if config.get("max_position_embeddings") is not None
            else None
        ),
        "llama.embedding_length": (
            mx.array(config["hidden_size"], dtype=mx.uint32)
            if config.get("hidden_size") is not None
            else None
        ),
        "llama.block_count": (
            mx.array(config["num_hidden_layers"], dtype=mx.uint32)
            if config.get("num_hidden_layers") is not None
            else None
        ),
        "llama.feed_forward_length": (
            mx.array(config["intermediate_size"], dtype=mx.uint32)
            if config.get("intermediate_size") is not None
            else None
        ),
        "llama.rope.dimension_count": (
            mx.array(
                config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
            )
            if config.get("hidden_size") is not None
            and config.get("num_attention_heads") is not None
            else None
        ),
        "llama.attention.head_count": (
            mx.array(config["num_attention_heads"], dtype=mx.uint32)
            if config.get("num_attention_heads") is not None
            else None
        ),
        "llama.attention.head_count_kv": (
            mx.array(
                config.get("num_key_value_heads", config["num_attention_heads"]),
                dtype=mx.uint32,
            )
            if config.get("num_attention_heads") is not None
            else None
        ),
        "llama.expert_count": (
            mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
            if config.get("num_local_experts") is not None
            else None
        ),
        "llama.expert_used_count": (
            mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
            if config.get("num_experts_per_tok") is not None
            else None
        ),
        "llama.attention.layer_norm_rms_epsilon": (
            mx.array(config.get("rms_norm_eps", 1e-05))
            if config.get("rms_norm_eps") is not None
            else None
        ),
        "llama.rope.freq_base": (
            mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
            if config.get("rope_theta") is not None
            else None
        ),
    }

    rope_scaling = config.get("rope_scaling")
    if rope_scaling is not None and (typ := rope_scaling.get("type")):
        rope_factor = rope_scaling.get("factor")
        f_rope_scale = rope_factor
        if typ == "linear":
            rope_scaling_type = "linear"
            metadata["llama.rope.scaling.type"] = rope_scaling_type
            metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)

    metadata["general.file_type"] = mx.array(
        GGMLFileType.GGML_TYPE_F16.value,
        dtype=mx.uint32,
    )
    metadata["general.quantization_version"] = mx.array(
        GGMLFileType.GGML_TYPE_F16.value,
        dtype=mx.uint32,
    )
    metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
    metadata["general.architecture"] = "llama"
    metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)

    # add metadata for vocab
    metadata["tokenizer.ggml.model"] = "llama"
    tokens = []
    scores = []
    toktypes = []
    for text, score, toktype in vocab.all_tokens():
        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype.value)
    assert len(tokens) == vocab.vocab_size
    metadata["tokenizer.ggml.tokens"] = tokens
    metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
    metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
    if vocab.tokenizer.bos_token_id is not None:
        metadata["tokenizer.ggml.bos_token_id"] = mx.array(
            vocab.tokenizer.bos_token_id, dtype=mx.uint32
        )
    if vocab.tokenizer.eos_token_id is not None:
        metadata["tokenizer.ggml.eos_token_id"] = mx.array(
            vocab.tokenizer.eos_token_id, dtype=mx.uint32
        )
    if vocab.tokenizer.unk_token_id is not None:
        metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
            vocab.tokenizer.unk_token_id, dtype=mx.uint32
        )

    metadata = {k: v for k, v in metadata.items() if v is not None}
    return metadata


def convert_to_gguf(
    model_path: Union[str, Path],
    weights: dict,
    config: dict,
    output_file_path: str,
):
    if isinstance(model_path, str):
        model_path = Path(model_path)

    quantization = config.get("quantization", None)
    if quantization:
        raise NotImplementedError(
            "Conversion of quantized models is not yet supported."
        )
    print("Converting to GGUF format")
    # https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems related to llama.cpp's multihead attention
    weights = {
        k: (
            permute_weights(
                v, config["num_attention_heads"], config["num_attention_heads"]
            )
            if "self_attn.q_proj.weight" in k
            else (
                permute_weights(
                    v, config["num_attention_heads"], config["num_key_value_heads"]
                )
                if "self_attn.k_proj.weight" in k
                else v
            )
        )
        for k, v in weights.items()
    }

    # rename weights for gguf format
    weights = {translate_weight_names(k): v for k, v in weights.items()}

    if not (model_path / "tokenizer.json").exists():
        raise ValueError("Tokenizer json not found")

    vocab = HfVocab.load(model_path)
    metadata = prepare_metadata(config, vocab)

    weights = {
        k: (
            v.astype(mx.float32).astype(mx.float16)
            if v.dtype == mx.bfloat16
            else v.astype(mx.float32) if "norm" in k else v
        )
        for k, v in weights.items()
    }

    output_file_path = output_file_path
    mx.save_gguf(output_file_path, weights, metadata)
    print(f"Converted GGUF model saved as: {output_file_path}")
@@ -1,335 +0,0 @@
# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import re
import types
from pathlib import Path

import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml

from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
    load_adapters,
    print_trainable_parameters,
)
from .utils import load, save_config

yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "fine_tune_type": "lora",
    "optimizer": "adam",
    "optimizer_config": {
        "adam": {},
        "adamw": {},
    },
    "data": "data/",
    "seed": 0,
    "num_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_path": "adapters",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "config": None,
    "grad_checkpoint": False,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    "mask_prompt": False,
}


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
        default=None,
    )
    parser.add_argument(
        "--data",
        type=str,
        help=(
            "Directory with {train, valid, test}.jsonl files or the name "
            "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
        ),
    )
    parser.add_argument(
        "--fine-tune-type",
        type=str,
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        choices=["adam", "adamw"],
        default=None,
        help="Optimizer to use for training: adam or adamw",
    )
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=None,
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        help="Number of layers to fine-tune. Default is 16, use -1 for all.",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training from the given fine-tuned weights.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the fine-tuned weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
        default=None,
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
        default=None,
    )
    parser.add_argument("--seed", type=int, help="The PRNG seed")
    return parser


def train_model(
    args,
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    train_set,
    valid_set,
    training_callback: TrainingCallback = None,
):
    model.freeze()
    if args.num_layers > len(model.layers):
        raise ValueError(
            f"Requested to train {args.num_layers} layers "
            f"but the model only has {len(model.layers)} layers."
        )

    if args.fine_tune_type == "full":
        for l in model.layers[-max(args.num_layers, 0) :]:
            l.unfreeze()
    elif args.fine_tune_type in ["lora", "dora"]:
        # Convert linear layers to lora/dora layers and unfreeze in the process
        linear_to_lora_layers(
            model,
            args.num_layers,
            args.lora_parameters,
            use_dora=(args.fine_tune_type == "dora"),
        )
    else:
        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

    # Resume from weights if provided
    if args.resume_adapter_file is not None:
        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    print_trainable_parameters(model)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)

    adapter_file = adapter_path / "adapters.safetensors"
    save_config(vars(args), adapter_path / "adapter_config.json")

    # init training args
    training_args = TrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=adapter_file,
        max_seq_length=args.max_seq_length,
        grad_checkpoint=args.grad_checkpoint,
    )

    model.train()

    # Initialize the selected optimizer
    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate

    optimizer_name = args.optimizer.lower()
    optimizer_config = args.optimizer_config.get(optimizer_name, {})

    if optimizer_name == "adam":
        opt_class = optim.Adam
    elif optimizer_name == "adamw":
        opt_class = optim.AdamW
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    opt = opt_class(learning_rate=lr, **optimizer_config)

    # Train model
    train(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        training_callback=training_callback,
    )


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
    model.eval()

    test_loss = evaluate(
        model=model,
        dataset=test_set,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.test_batches,
        max_seq_length=args.max_seq_length,
    )

    test_ppl = math.exp(test_loss)

    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    if args.test and not args.train:
        # Allow testing without LoRA layers by providing empty path
        if args.adapter_path != "":
            load_adapters(model, args.adapter_path)

    elif args.train:
        print("Training")
        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
    else:
        raise ValueError("Must provide at least one of --train or --test")

    if args.test:
        print("Testing")
        evaluate_model(args, model, tokenizer, test_set)


def main():
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    parser = build_parser()
    args = parser.parse_args()
    config = args.config
    args = vars(args)
    if config:
        print("Loading configuration file", config)
        with open(config, "r") as file:
            config = yaml.load(file, yaml_loader)
        # Prefer parameters from command-line arguments
        for k, v in config.items():
            if args.get(k, None) is None:
                args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if args.get(k, None) is None:
            args[k] = v
    run(types.SimpleNamespace(**args))


if __name__ == "__main__":
    main()
@@ -1,139 +0,0 @@
import argparse
from typing import List, Union

from huggingface_hub import scan_cache_dir


def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """
    Inspired by:
    - stackoverflow.com/a/8356620/593036
    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
    """
    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
    lines = []
    lines.append(row_format.format(*headers))
    lines.append(row_format.format(*["-" * w for w in col_widths]))
    for row in rows:
        lines.append(row_format.format(*row))
    return "\n".join(lines)


def ask_for_confirmation(message: str) -> bool:
    """Ask user for confirmation with Y/N prompt.
    Returns True for Y/yes, False for N/no/empty."""
    y = ("y", "yes", "1")
    n = ("n", "no", "0", "")
    full_message = f"{message} (y/n) "
    while True:
        answer = input(full_message).lower()
        if answer in y:
            return True
        if answer in n:
            return False
        print("Invalid input. Must be one of: yes/no/y/n or empty for no")


def main():
    parser = argparse.ArgumentParser(description="MLX Model Cache.")
    parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan Hugging Face cache for mlx models.",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="Delete models matching the given pattern.",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        help="Only consider model repos containing the pattern.",
        default="mlx",
    )

    args = parser.parse_args()

    if args.scan:
        print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".')
        hf_cache_info = scan_cache_dir()
        print(
            tabulate(
                rows=[
                    [
                        repo.repo_id,
                        repo.repo_type,
                        "{:>12}".format(repo.size_on_disk_str),
                        repo.nb_files,
                        repo.last_accessed_str,
                        repo.last_modified_str,
                        str(repo.repo_path),
                    ]
                    for repo in sorted(
                        hf_cache_info.repos, key=lambda repo: repo.repo_path
                    )
                    if args.pattern in repo.repo_id
                ],
                headers=[
                    "REPO ID",
                    "REPO TYPE",
                    "SIZE ON DISK",
                    "NB FILES",
                    "LAST_ACCESSED",
                    "LAST_MODIFIED",
                    "LOCAL PATH",
                ],
            )
        )

    if args.delete:
        print(f'Deleting models matching pattern "{args.pattern}"')
        hf_cache_info = scan_cache_dir()

        repos = [
            repo
            for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
            if args.pattern in repo.repo_id
        ]
        if repos:
            print("\nFound the following models:")
            print(
                tabulate(
                    rows=[
                        [
                            repo.repo_id,
                            repo.size_on_disk_str,  # Added size information
                            str(repo.repo_path),
                        ]
                        for repo in repos
                    ],
                    headers=[
                        "REPO ID",
                        "SIZE",  # Added size header
                        "LOCAL PATH",
                    ],
                )
            )

            confirmed = ask_for_confirmation(
                "\nAre you sure you want to delete these models?"
            )
            if confirmed:
                for model_info in repos:
                    print(f"\nDeleting {model_info.repo_id}...")
                    for revision in sorted(
                        model_info.revisions, key=lambda revision: revision.commit_hash
                    ):
                        strategy = hf_cache_info.delete_revisions(revision.commit_hash)
                        strategy.execute()
                print("\nModel(s) deleted successfully.")
            else:
                print("\nDeletion cancelled - no changes made.")
        else:
            print(f'No models found matching pattern "{args.pattern}"')


if __name__ == "__main__":
    main()
@@ -1,172 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import yaml
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Merge multiple models.")
|
||||
|
||||
parser.add_argument("--config", type=str, help="Path to the YAML config.")
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_merged_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def slerp(t, w1, w2, eps=1e-5):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
|
||||
Args:
|
||||
t (float): Interpolation weight in [0.0, 1.0]
|
||||
w1 (mx.array): First input
|
||||
w2 (mx.array): Second input
|
||||
eps (float): Constant for numerical stability
|
||||
Returns:
|
||||
mx.array: Interpolated result
|
||||
"""
|
||||
t = float(t)
|
||||
if t == 0:
|
||||
return w1
|
||||
elif t == 1:
|
||||
return w2
|
||||
# Normalize
|
||||
v1 = w1 / mx.linalg.norm(w1)
|
||||
v2 = w2 / mx.linalg.norm(w2)
|
||||
# Angle
|
||||
dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
|
||||
theta = mx.arccos(dot)
|
||||
sin_theta = mx.sin(theta + eps)
|
||||
s1 = mx.sin(theta * (1 - t)) / sin_theta
|
||||
s2 = mx.sin(theta * t) / sin_theta
|
||||
return s1 * w1 + s2 * w2
|
||||
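# Illustrative sketch (added example, not part of the original file): spherical
# interpolation halfway along the arc between two weight arrays. Relies on the
# `mx` import at the top of this module.
def _example_slerp() -> mx.array:
    w1 = mx.array([1.0, 0.0])
    w2 = mx.array([0.0, 1.0])
    # t=0.5 weights both inputs equally; the result is approximately [0.707, 0.707].
    return slerp(0.5, w1, w2)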
|
||||
|
||||
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
|
||||
method = config.get("method", None)
|
||||
if method != "slerp":
|
||||
raise ValueError(f"Merge method {method} not supported")
|
||||
|
||||
num_layers = len(model.layers)
|
||||
|
||||
def unpack_values(vals):
|
||||
if isinstance(vals, (int, float)):
|
||||
return np.full(num_layers, vals)
|
||||
bins = len(vals) - 1
|
||||
sizes = [num_layers // bins] * bins
|
||||
sizes[-1] = num_layers - sum(sizes[:-1])
|
||||
return np.concatenate(
|
||||
[np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
|
||||
)
|
||||
|
||||
param_list = config["parameters"]["t"]
|
||||
params = {}
|
||||
filter_keys = set()
|
||||
for pl in param_list[:-1]:
|
||||
params[pl["filter"]] = unpack_values(pl["value"])
|
||||
filter_keys.add(pl["filter"])
|
||||
default = unpack_values(param_list[-1]["value"])
|
||||
|
||||
for e in range(num_layers):
|
||||
bl = base_model.layers[e]
|
||||
l = model.layers[e]
|
||||
base_weights = bl.parameters()
|
||||
weights = l.parameters()
|
||||
for k, w1 in base_weights.items():
|
||||
w2 = weights[k]
|
||||
t = params.get(k, default)[e]
|
||||
base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
|
||||
base_model.update(base_weights)
|
||||
|
||||
|
||||
def merge(
|
||||
config: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
upload_repo: Optional[str] = None,
|
||||
):
|
||||
with open(config, "r") as fid:
|
||||
merge_conf = yaml.safe_load(fid)
|
||||
print("[INFO] Loading")
|
||||
|
||||
model_paths = merge_conf.get("models", [])
|
||||
if len(model_paths) < 2:
|
||||
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
|
||||
|
||||
# Load all models
|
||||
base_hf_path = model_paths[0]
|
||||
base_path = get_model_path(base_hf_path)
|
||||
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||
models = []
|
||||
for mp in model_paths[1:]:
|
||||
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||
base_type = base_config["model_type"]
|
||||
model_type = model_config["model_type"]
|
||||
if base_type != model_type:
|
||||
raise ValueError(
|
||||
f"Can only merge models of the same type,"
|
||||
f" but got {base_type} and {model_type}."
|
||||
)
|
||||
models.append(model)
|
||||
|
||||
# Merge models into base model
|
||||
for m in models:
|
||||
merge_models(base_model, m, merge_conf)
|
||||
|
||||
# Save base model
|
||||
mlx_path = Path(mlx_path)
|
||||
weights = dict(tree_flatten(base_model.parameters()))
|
||||
del models, base_model
|
||||
save_weights(mlx_path, weights, donate_weights=True)
|
||||
py_files = glob.glob(str(base_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, mlx_path)
|
||||
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
|
||||
save_config(config, config_path=mlx_path / "config.json")
|
||||
|
||||
if upload_repo is not None:
|
||||
upload_to_hub(mlx_path, upload_repo, base_hf_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
merge(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,120 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
|
||||
from .cache import QuantizedKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelArgs:
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_causal_mask(
|
||||
N: int,
|
||||
offset: int = 0,
|
||||
window_size: Optional[int] = None,
|
||||
lengths: Optional[mx.array] = None,
|
||||
):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
linds = linds[:, None]
|
||||
rinds = rinds[None]
|
||||
mask = linds >= rinds
|
||||
if window_size is not None:
|
||||
mask = mask & (linds <= rinds + window_size)
|
||||
if lengths is not None:
|
||||
lengths = lengths[:, None, None, None]
|
||||
mask = mask & (rinds < lengths)
|
||||
return mask
|
||||
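# Illustrative sketch (added example, not part of the original file): a causal
# mask for 3 new tokens appended after 2 already-cached positions. Row i of the
# result lets query i attend to every cached position and to new positions <= i.
def _example_create_causal_mask() -> mx.array:
    mask = create_causal_mask(N=3, offset=2)
    # mask has shape (3, 5); its first row is [True, True, True, False, False].
    return mask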
|
||||
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
window_size = None
|
||||
offset = 0
|
||||
if cache is not None and cache[0] is not None:
|
||||
c = cache[0]
|
||||
if hasattr(c, "max_size"):
|
||||
offset = min(c.max_size, c.offset)
|
||||
window_size = c.max_size
|
||||
else:
|
||||
offset = c.offset
|
||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||
else:
|
||||
mask = None
|
||||
return mask
|
||||
|
||||
|
||||
def quantized_scaled_dot_product_attention(
|
||||
queries: mx.array,
|
||||
q_keys: tuple[mx.array, mx.array, mx.array],
|
||||
q_values: tuple[mx.array, mx.array, mx.array],
|
||||
scale: float,
|
||||
mask: Optional[mx.array],
|
||||
group_size: int = 64,
|
||||
bits: int = 8,
|
||||
) -> mx.array:
|
||||
B, n_q_heads, L, D = queries.shape
|
||||
n_kv_heads = q_keys[0].shape[-3]
|
||||
n_repeats = n_q_heads // n_kv_heads
|
||||
|
||||
queries *= scale
|
||||
|
||||
if n_repeats > 1:
|
||||
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
|
||||
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
|
||||
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
|
||||
|
||||
scores = mx.quantized_matmul(
|
||||
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
|
||||
)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
out = mx.quantized_matmul(
|
||||
scores, *q_values, transpose=False, group_size=group_size, bits=bits
|
||||
)
|
||||
|
||||
if n_repeats > 1:
|
||||
out = mx.reshape(out, (B, n_q_heads, L, D))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache,
|
||||
scale: float,
|
||||
mask: Optional[mx.array],
|
||||
) -> mx.array:
|
||||
if isinstance(cache, QuantizedKVCache):
|
||||
return quantized_scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
group_size=cache.group_size,
|
||||
bits=cache.bits,
|
||||
)
|
||||
else:
|
||||
return mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=scale, mask=mask
|
||||
)
|
||||
@@ -1,438 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(
|
||||
model: nn.Module,
|
||||
max_kv_size: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
    Construct the model's cache for use during generation.
|
||||
|
||||
This function will defer the cache construction to the model if it has a
|
||||
``make_cache`` method, otherwise it will make a default KV cache.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||
size of ``max_kv_size``
|
||||
"""
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
num_layers = len(model.layers)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
return [KVCache() for _ in range(num_layers)]
|
||||
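# Illustrative sketch (added example, not part of the original file): building
# per-layer caches for a loaded model. `model` is assumed to be any mlx-lm
# language model; passing `max_kv_size` bounds memory with a RotatingKVCache.
def _example_make_prompt_cache(model):
    unbounded = make_prompt_cache(model)
    bounded = make_prompt_cache(model, max_kv_size=1024)
    return unbounded, bounded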
|
||||
|
||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
cache (List[Any]): The model state.
|
||||
metadata (Dict[str, str]): Optional metadata to save along with model
|
||||
state.
|
||||
"""
|
||||
cache_data = [c.state for c in cache]
|
||||
cache_info = [c.meta_state for c in cache]
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
cache_classes = [type(c).__name__ for c in cache]
|
||||
cache_metadata = [cache_info, metadata, cache_classes]
|
||||
cache_metadata = dict(tree_flatten(cache_metadata))
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
|
||||
def load_prompt_cache(file_name, return_metadata=False):
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
return_metadata (bool): Whether or not to return metadata.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
|
||||
the metadata if requested.
|
||||
"""
|
||||
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
|
||||
arrays = tree_unflatten(list(arrays.items()))
|
||||
cache_metadata = tree_unflatten(list(cache_metadata.items()))
|
||||
info, metadata, classes = cache_metadata
|
||||
cache = [globals()[c]() for c in classes]
|
||||
for c, state, meta_state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
c.meta_state = meta_state
|
||||
if return_metadata:
|
||||
return cache, metadata
|
||||
return cache
|
||||
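# Illustrative sketch (added example, not part of the original file): round-trip
# a prompt cache through a ``.safetensors`` file. `cache` is assumed to come
# from `make_prompt_cache` after a prompt has been processed.
def _example_cache_round_trip(cache, file_name="prompt_cache.safetensors"):
    save_prompt_cache(file_name, cache, metadata={"prompt": "hello"})
    restored, metadata = load_prompt_cache(file_name, return_metadata=True)
    return restored, metadata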
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
return all(c.is_trimmable() for c in cache)
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
|
||||
This function will trim the cache if possible (in-place) and return the
|
||||
number of tokens that were trimmed.
|
||||
|
||||
Args:
|
||||
cache (List[Any]): The model's cache.
|
||||
num_tokens (int): The number of tokens to trim.
|
||||
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
||||
class _BaseCache:
|
||||
@property
|
||||
def state(self):
|
||||
return []
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no state but a state was set.")
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return ""
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no meta_state but a meta_state was set.")
|
||||
|
||||
def is_trimmable(self):
|
||||
return False
|
||||
|
||||
|
||||
class QuantizedKVCache(_BaseCache):
|
||||
def __init__(self, group_size: int = 64, bits: int = 8):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[-1]
|
||||
prev = self.offset
|
||||
|
||||
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
||||
el_per_int = 8 * mx.uint32.size // self.bits
|
||||
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
||||
shape = (B, n_kv_heads, new_steps)
|
||||
|
||||
def init_quant(dim):
|
||||
return (
|
||||
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||
)
|
||||
|
||||
def expand_quant(x):
|
||||
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
|
||||
return mx.concatenate([x, new_x], axis=-2)
|
||||
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys, self.values = tree_map(
|
||||
lambda x: x[..., :prev, :], (self.keys, self.values)
|
||||
)
|
||||
|
||||
self.keys, self.values = tree_map(
|
||||
expand_quant, (self.keys, self.values)
|
||||
)
|
||||
else:
|
||||
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
|
||||
|
||||
self.offset += num_steps
|
||||
|
||||
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
|
||||
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
|
||||
for i in range(len(self.keys)):
|
||||
self.keys[i][..., prev : self.offset, :] = keys[i]
|
||||
self.values[i][..., prev : self.offset, :] = values[i]
|
||||
|
||||
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys[0].shape[2]:
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return tree_map(
|
||||
lambda x: x[..., : self.offset, :], (self.keys, self.values)
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.step, self.offset, self.group_size, self.bits = map(int, v)
|
||||
|
||||
def is_trimmable(self):
|
||||
return True
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B, n_kv_heads, _, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[3]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys.shape[2]:
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return (
|
||||
self.keys[..., : self.offset, :],
|
||||
self.values[..., : self.offset, :],
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
self.offset = self.keys.shape[2]
|
||||
|
||||
def is_trimmable(self):
|
||||
return True
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||
quant_cache.offset = self.offset
|
||||
if self.keys is not None:
|
||||
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
||||
quant_cache.values = mx.quantize(
|
||||
self.values, group_size=group_size, bits=bits
|
||||
)
|
||||
return quant_cache
|
||||
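# Illustrative sketch (added example, not part of the original file): the cache
# allocates its buffers in steps of 256 positions but only returns (and counts)
# the positions written so far.
def _example_kv_cache():
    cache = KVCache()
    keys = mx.zeros((1, 8, 10, 64))  # (batch, kv_heads, seq_len, head_dim)
    values = mx.zeros((1, 8, 10, 64))
    k, v = cache.update_and_fetch(keys, values)
    assert k.shape[2] == cache.offset == 10
    return k, v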
|
||||
|
||||
class RotatingKVCache(_BaseCache):
|
||||
|
||||
def __init__(self, max_size=None, keep=0, step=256):
|
||||
self.keep = keep
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
to_cat = []
|
||||
if trim_size > 0:
|
||||
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||
else:
|
||||
to_cat = [v]
|
||||
if append is not None:
|
||||
to_cat.append(append)
|
||||
return mx.concatenate(to_cat, axis=2)
|
||||
|
||||
def _temporal_order(self, v):
|
||||
"""
|
||||
Rearrange the cache into temporal order, slicing off the end if unused.
|
||||
"""
|
||||
if self._idx == v.shape[2]:
|
||||
return v
|
||||
elif self._idx < self.offset:
|
||||
return mx.concatenate(
|
||||
[
|
||||
v[..., : self.keep, :],
|
||||
v[..., self._idx :, :],
|
||||
v[..., self.keep : self._idx, :],
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
else:
|
||||
return v[..., : self._idx, :]
|
||||
|
||||
def _update_concat(self, keys, values):
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
# Put the keys/values in temporal order to
|
||||
# preserve context
|
||||
self.keys = self._temporal_order(self.keys)
|
||||
self.values = self._temporal_order(self.values)
|
||||
|
||||
# The largest size is self.max_size + S to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self._idx - self.max_size
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += keys.shape[2]
|
||||
self._idx = self.keys.shape[2]
|
||||
return self.keys, self.values
|
||||
|
||||
def _update_in_place(self, keys, values):
|
||||
# May not have hit the max size yet, so potentially
|
||||
# keep growing the cache
|
||||
B, n_kv_heads, S, k_head_dim = keys.shape
|
||||
prev = self.offset
|
||||
if self.keys is None or (
|
||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||
):
|
||||
v_head_dim = values.shape[3]
|
||||
new_size = min(self.step, self.max_size - prev)
|
||||
k_shape = (B, n_kv_heads, new_size, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, new_size, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
self._idx = prev
|
||||
|
||||
# Trim if needed
|
||||
trim_size = self.keys.shape[2] - self.max_size
|
||||
if trim_size > 0:
|
||||
self.keys = self._trim(trim_size, self.keys)
|
||||
self.values = self._trim(trim_size, self.values)
|
||||
self._idx = self.max_size
|
||||
|
||||
# Rotate
|
||||
if self._idx == self.max_size:
|
||||
self._idx = self.keep
|
||||
|
||||
# Assign
|
||||
self.keys[..., self._idx : self._idx + S, :] = keys
|
||||
self.values[..., self._idx : self._idx + S, :] = values
|
||||
self.offset += S
|
||||
self._idx += S
|
||||
|
||||
# If the buffer is not full, slice off the end
|
||||
if self.offset < self.max_size:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
if keys.shape[2] == 1:
|
||||
return self._update_in_place(keys, values)
|
||||
return self._update_concat(keys, values)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset < self.keys.shape[2]:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
else:
|
||||
return self.keys, self.values
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(
|
||||
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
|
||||
)
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.keep, self.max_size, self.step, self.offset, self._idx = map(
|
||||
int,
|
||||
v,
|
||||
)
|
||||
|
||||
def is_trimmable(self):
|
||||
return self.offset < self.max_size
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
self._idx -= n
|
||||
return n
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||
raise NotImplementedError("RotatingKVCache Quantization NYI")
|
||||
|
||||
|
||||
class MambaCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
def __setitem__(self, idx, value):
|
||||
self.cache[idx] = value
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cache[idx]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v
|
||||
@@ -1,195 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 8192
|
||||
num_hidden_layers: int = 40
|
||||
intermediate_size: int = 22528
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 64
|
||||
rope_theta: float = 8000000.0
|
||||
vocab_size: int = 256000
|
||||
layer_norm_eps: float = 1e-05
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
use_qk_norm: bool = False
|
||||
|
||||
|
||||
class LayerNorm2D(nn.Module):
|
||||
|
||||
def __init__(self, d1, d2, eps):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((d1, d2))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
        attention_bias = args.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.use_qk_norm = args.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
|
||||
self.k_norm = LayerNorm2D(
|
||||
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.n_heads, -1)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1)
|
||||
if self.use_qk_norm:
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = CohereModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = out * self.model.args.logit_scale
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,206 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .cache import KVCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 4096
|
||||
head_dim: int = 128
|
||||
num_hidden_layers: int = 32
|
||||
intermediate_size: int = 14336
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
rope_theta: float = 50000.0
|
||||
vocab_size: int = 256000
|
||||
layer_norm_eps: float = 1e-05
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
sliding_window: int = 4096
|
||||
sliding_window_pattern: int = 4
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
if (head_dim * n_heads) != dim:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
|
||||
f" and `num_heads`: {n_heads})."
|
||||
)
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
        attention_bias = args.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Apply RoPE only if sliding window is enabled
|
||||
if self.use_sliding_window:
|
||||
if cache is None:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
else:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
if self.use_sliding_window and mask is not None:
|
||||
key_len = keys.shape[-2]
|
||||
if mask.shape[-1] != key_len:
|
||||
mask = mask[..., -key_len:]
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args, layer_idx)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args, layer_idx=i)
|
||||
for i in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
if mask is None:
|
||||
j = self.args.sliding_window_pattern
|
||||
mask = create_attention_mask(h, cache[j - 1 : j])
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = CohereModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = out * self.model.args.logit_scale
|
||||
return out
|
||||
|
||||
def make_cache(self):
|
||||
caches = []
|
||||
for i in range(self.args.num_hidden_layers):
|
||||
if (
|
||||
i % self.args.sliding_window_pattern
|
||||
== self.args.sliding_window_pattern - 1
|
||||
):
|
||||
caches.append(KVCache())
|
||||
else:
|
||||
caches.append(
|
||||
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
|
||||
)
|
||||
return caches
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,254 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
d_model: int
|
||||
ffn_config: dict
|
||||
attn_config: dict
|
||||
n_layers: int
|
||||
n_heads: int
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_heads = args.n_heads
|
||||
self.d_model = args.d_model
|
||||
self.head_dim = args.d_model // args.n_heads
|
||||
self.num_key_value_heads = args.attn_config["kv_n_heads"]
|
||||
self.clip_qkv = args.attn_config["clip_qkv"]
|
||||
self.rope_theta = args.attn_config["rope_theta"]
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.Wqkv = nn.Linear(
|
||||
args.d_model,
|
||||
(self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
|
||||
splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
|
||||
queries, keys, values = mx.split(qkv, splits, axis=-1)
|
||||
|
||||
B, L, D = x.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class NormAttnNorm(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
|
||||
self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
|
||||
self.attn = Attention(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
|
||||
x = h + x
|
||||
return x, self.norm_2(x)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, d_model: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||
self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||
self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
|
||||
current_hidden_states = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class Router(nn.Module):
|
||||
def __init__(self, d_model: int, num_experts: int):
|
||||
super().__init__()
|
||||
self.layer = nn.Linear(d_model, num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class SparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.d_model = args.d_model
|
||||
self.ffn_dim = args.ffn_config["ffn_hidden_size"]
|
||||
self.num_experts = args.ffn_config["moe_num_experts"]
|
||||
self.num_experts_per_tok = args.ffn_config["moe_top_k"]
|
||||
|
||||
self.router = Router(self.d_model, self.num_experts)
|
||||
self.experts = [
|
||||
MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.router(x)
|
||||
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
|
||||
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
|
||||
scores = scores.astype(x.dtype)
|
||||
|
||||
if self.training:
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt)
|
||||
y = mx.stack(y, axis=0)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ffn = SparseMoeBlock(args)
|
||||
self.norm_attn_norm = NormAttnNorm(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r, h = self.norm_attn_norm(x, mask, cache)
|
||||
out = self.ffn(h) + r
|
||||
return out
|
||||
|
||||
|
||||
class DBRX(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = args.vocab_size
|
||||
self.wte = nn.Embedding(args.vocab_size, args.d_model)
|
||||
self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
|
||||
self.norm_f = nn.LayerNorm(args.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.blocks)
|
||||
|
||||
for layer, c in zip(self.blocks, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm_f(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.transformer = DBRX(args)
|
||||
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.blocks
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Split experts into sub matrices
|
||||
num_experts = self.args.ffn_config["moe_num_experts"]
|
||||
dim = self.args.ffn_config["ffn_hidden_size"]
|
||||
|
||||
pattern = "experts.mlp"
|
||||
new_weights = {k: v for k, v in weights.items() if pattern not in k}
|
||||
for k, v in weights.items():
|
||||
if pattern in k:
|
||||
experts = [
|
||||
(k.replace(".mlp", f".{e}") + ".weight", sv)
|
||||
for e, sv in enumerate(mx.split(v, num_experts, axis=0))
|
||||
]
|
||||
if k.endswith("w2"):
|
||||
experts = [(s, sv.T) for s, sv in experts]
|
||||
new_weights.update(experts)
|
||||
return new_weights
|
||||
@@ -1,261 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Optional[Dict] = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
class DeepseekAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
attention_bias = getattr(config, "attention_bias", False)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_attention_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_key_value_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_key_value_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.num_attention_heads * self.head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
rope_scale = 1.0
|
||||
if config.rope_scaling and config.rope_scaling["type"] == "linear":
|
||||
assert isinstance(config.rope_scaling["factor"], float)
|
||||
rope_scale = 1 / config.rope_scaling["factor"]
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
base=config.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = hidden_size or config.hidden_size
|
||||
self.intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
scores = mx.softmax(gates, axis=-1, precise=True)
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
return inds, scores
|
||||
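# Illustrative sketch (added example, not part of the original file): the gate
# scores every routed expert with a linear map, keeps the top-k via
# `mx.argpartition`, and returns the chosen expert indices together with their
# softmax scores (not renormalized over the selected k).
def _example_moe_gate(gate: MoEGate, x: mx.array):
    inds, scores = gate(x)
    assert inds.shape[-1] == gate.top_k
    return inds, scores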
|
||||
|
||||
class DeepseekMoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekMLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekAttention(config)
|
||||
self.mlp = (
|
||||
DeepseekMoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekMLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for m in ["gate_proj", "down_proj", "up_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,462 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v2"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
    topk_method: str = "greedy"
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
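# Illustrative sketch (added example, not part of the original file): the ramp
# YaRN uses to blend extrapolated and interpolated RoPE frequencies. Indices
# below `min_val` map to 0, those above `max_val` to 1, with a linear ramp in
# between.
def _example_yarn_ramp() -> mx.array:
    ramp = yarn_linear_ramp_mask(2.0, 6.0, 8)
    # ramp == [0, 0, 0, 0.25, 0.5, 0.75, 1, 1]
    return ramp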
|
||||
|
||||
class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
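        # Multi-head latent attention: keys/values are compressed into a low-rank
        # latent of size kv_lora_rank plus a shared rotary branch of size
        # qk_rope_head_dim, then expanded per head by kv_b_proj.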
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
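        # YaRN also rescales the attention logits: fold the squared mscale factor
        # into the softmax scale.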
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV2YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
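        # Group-limited routing: rank expert groups by their best expert score and
        # zero out every group except the topk_group highest before the final top-k.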
if self.topk_method == "group_limited_greedy":
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = scores.max(axis=-1, keepdims=True)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||
scores = mx.put_along_axis(
|
||||
scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
|
||||
)
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV2Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV2MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV2MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV2DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = len(self.layers) // self.pipeline_size
|
||||
extra = len(self.layers) - layers_per_rank * self.pipeline_size
|
||||
if self.pipeline_rank < extra:
|
||||
layers_per_rank += 1
|
||||
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
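        # Keep only this rank's contiguous slice of layers; earlier slots become
        # None placeholders so absolute layer indices stay valid.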
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV2Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
@@ -1,506 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v3"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
topk_method: str = "noaux_tc"
|
||||
scoring_func: str = "sigmoid"
|
||||
norm_topk_prob: bool = True
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
# A clipped silu to prevent fp16 from overflowing
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def clipped_silu(x):
|
||||
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV3YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
else:
|
||||
self.rope = nn.RoPE(
|
||||
dims=self.qk_rope_head_dim,
|
||||
base=self.rope_theta,
|
||||
traditional=True,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
@mx.compile
|
||||
def group_expert_select(
|
||||
gates,
|
||||
e_score_correction_bias,
|
||||
top_k,
|
||||
n_group,
|
||||
topk_group,
|
||||
routed_scaling_factor,
|
||||
norm_topk_prob,
|
||||
):
|
||||
|
||||
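    # "noaux_tc" routing: sigmoid gate scores plus a learned per-expert correction
    # bias; groups are ranked by the sum of their top-2 scores and only the
    # topk_group best groups survive before the per-token top-k.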
k = top_k
|
||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||
scores = scores + e_score_correction_bias
|
||||
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
|
||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
|
||||
k = n_group - topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
|
||||
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
|
||||
scores = mx.flatten(scores, -2, -1)
|
||||
|
||||
k = top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
if top_k > 1 and norm_topk_prob:
|
||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||
scores = scores / denominator
|
||||
scores = scores * routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||
assert config.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
|
||||
def __call__(self, x):
|
||||
return group_expert_select(
|
||||
x @ self.weight.T,
|
||||
self.e_score_correction_bias,
|
||||
self.top_k,
|
||||
self.n_group,
|
||||
self.topk_group,
|
||||
self.routed_scaling_factor,
|
||||
self.norm_topk_prob,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size,
|
||||
config.moe_intermediate_size,
|
||||
config.n_routed_experts,
|
||||
activation=clipped_silu,
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV3DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = len(self.layers) // self.pipeline_size
|
||||
extra = len(self.layers) - layers_per_rank * self.pipeline_size
|
||||
if self.pipeline_rank < extra:
|
||||
layers_per_rank += 1
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV3Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
# Remove multi-token prediction layer and any unused precomputed rotary freqs
|
||||
return {
|
||||
k: v
|
||||
for k, v in weights.items()
|
||||
if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
@@ -1,166 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .rope_utils import initialize_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
vocab_size: int
|
||||
rope_theta: float
|
||||
layer_norm_epsilon: float
|
||||
num_key_value_heads: int
|
||||
head_dim: Optional[int] = None
|
||||
max_position_embeddings: Optional[int] = None
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim or (dim // n_heads)
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
||||
|
||||
self.rope = initialize_rope(
|
||||
self.head_dim,
|
||||
args.rope_theta,
|
||||
args.rope_traditional,
|
||||
args.rope_scaling,
|
||||
args.max_position_embeddings,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
q = self.rope(q, offset=cache.offset)
|
||||
k = self.rope(k, offset=cache.offset)
|
||||
k, v = cache.update_and_fetch(k, v)
|
||||
else:
|
||||
q = self.rope(q)
|
||||
k = self.rope(k)
|
||||
|
||||
out = scaled_dot_product_attention(
|
||||
q, k, v, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
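# Thin wrapper so parameter names follow the EXAONE checkpoint layout
# (transformer.h.{i}.attn.attention.*).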
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.attention = AttentionModule(args)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
|
||||
self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = x + self.attn.attention(self.ln_1(x), mask, cache)
|
||||
out = h + self.mlp(self.ln_2(h))
|
||||
return out
|
||||
|
||||
|
||||
class ExaoneModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
|
||||
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.ln_f(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = ExaoneModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,178 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
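        # Gemma scales token embeddings by sqrt(hidden_size) before the first block.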
h = h * (self.args.hidden_size**0.5)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = GemmaModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,203 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
attn_logit_softcapping: float = 50.0
|
||||
final_logit_softcapping: float = 30.0
|
||||
query_pre_attn_scalar: float = 144.0
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
|
||||
self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.attn_logit_softcapping = args.attn_logit_softcapping
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries * self.scale
|
||||
|
||||
if self.repeats > 1:
|
||||
queries = queries.reshape(
|
||||
B, self.n_kv_heads, self.repeats, L, self.head_dim
|
||||
)
|
||||
keys = mx.expand_dims(keys, 2)
|
||||
values = mx.expand_dims(values, 2)
|
||||
|
||||
scores = queries @ keys.swapaxes(-1, -2)
|
||||
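        # Soft-cap the attention logits: tanh(scores / cap) * cap keeps them in
        # (-attn_logit_softcapping, attn_logit_softcapping).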
scores = mx.tanh(scores / self.attn_logit_softcapping)
|
||||
scores *= self.attn_logit_softcapping
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, precise=True, axis=-1)
|
||||
output = scores @ values
|
||||
if self.repeats > 1:
|
||||
output = output.reshape(B, self.n_heads, L, self.head_dim)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + self.post_attention_layernorm(r)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h))
|
||||
out = h + self.post_feedforward_layernorm(r)
|
||||
return out
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
h = h * (self.args.hidden_size**0.5)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.final_logit_softcapping = args.final_logit_softcapping
|
||||
self.model = GemmaModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = mx.tanh(out / self.final_logit_softcapping)
|
||||
out = out * self.final_logit_softcapping
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,238 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .cache import KVCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 1152
|
||||
num_hidden_layers: int = 26
|
||||
intermediate_size: int = 6912
|
||||
num_attention_heads: int = 4
|
||||
head_dim: int = 256
|
||||
rms_norm_eps: float = 1.0e-6
|
||||
vocab_size: int = 262144
|
||||
num_key_value_heads: int = 1
|
||||
rope_global_base_freq: float = 1_000_000.0
|
||||
rope_local_base_freq: float = 10_000.0
|
||||
rope_traditional: bool = False
|
||||
query_pre_attn_scalar: float = 256
|
||||
sliding_window: int = 512
|
||||
sliding_window_pattern: int = 6
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.scale = args.query_pre_attn_scalar**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
|
||||
self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
|
||||
self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=(
|
||||
args.rope_local_base_freq
|
||||
if self.is_sliding
|
||||
else args.rope_global_base_freq
|
||||
),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Sliding window
|
||||
if mask is not None and mask.shape[-1] != keys.shape[-2]:
|
||||
mask = mask[..., -keys.shape[-2] :]
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args, layer_idx)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + self.post_attention_layernorm(r)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h))
|
||||
out = h + self.post_feedforward_layernorm(r)
|
||||
return out
|
||||
|
||||
|
||||
class Gemma3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args, layer_idx=layer_idx)
|
||||
for layer_idx in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
|
||||
h = self.embed_tokens(inputs)
|
||||
h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
if mask is None:
|
||||
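            # Two masks are needed: a full causal mask sized by the global layer's
            # cache and a sliding-window mask sized by the first rotating cache.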
j = self.args.sliding_window_pattern
|
||||
full_mask = create_attention_mask(h, cache[j - 1 : j])
|
||||
sliding_window_mask = create_attention_mask(h, cache)
|
||||
|
||||
for i, (layer, c) in enumerate(zip(self.layers, cache)):
|
||||
is_sliding = (
|
||||
i % self.args.sliding_window_pattern
|
||||
== self.args.sliding_window_pattern - 1
|
||||
)
|
||||
|
||||
if mask is None and is_sliding:
|
||||
mask = sliding_window_mask
|
||||
elif mask is None:
|
||||
mask = full_mask
|
||||
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Gemma3Model(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
def make_cache(self):
|
||||
caches = []
|
||||
for i in range(self.args.num_hidden_layers):
|
||||
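            # Every sliding_window_pattern-th layer attends globally and gets a
            # standard KVCache; all other layers use a RotatingKVCache capped at
            # the sliding window.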
if (
|
||||
i % self.args.sliding_window_pattern
|
||||
== self.args.sliding_window_pattern - 1
|
||||
):
|
||||
caches.append(KVCache())
|
||||
else:
|
||||
caches.append(
|
||||
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
|
||||
)
|
||||
return caches
|
||||
@@ -1,201 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_ctx: int
|
||||
n_embd: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
|
||||
|
||||
self.n_embd = args.n_embd
|
||||
self.n_head = args.n_head
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.n_embd = args.n_embd
|
||||
self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
|
||||
self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu_approx(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.layer_norm_epsilon = args.layer_norm_epsilon
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(
|
||||
self.n_embd,
|
||||
eps=self.layer_norm_epsilon,
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_embd = args.n_embd
|
||||
self.n_positions = args.n_positions
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layer = args.n_layer
|
||||
self.layer_norm_epsilon = args.layer_norm_epsilon
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(self.vocab_size, self.n_embd)
|
||||
self.wpe = nn.Embedding(self.n_positions, self.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
_, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
|
||||
if hidden_states.shape[1] > 1:
|
||||
|
||||
position_ids = mx.array(np.arange(L))
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GPT2Model(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.wte.as_linear(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for i in range(self.args.n_layer):
|
||||
if f"h.{i}.attn.bias" in weights:
|
||||
del weights[f"h.{i}.attn.bias"]
|
||||
if f"h.{i}.attn.c_attn.weight" in weights:
|
||||
weights[f"h.{i}.attn.c_attn.weight"] = weights[
|
||||
f"h.{i}.attn.c_attn.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.attn.c_proj.weight" in weights:
|
||||
weights[f"h.{i}.attn.c_proj.weight"] = weights[
|
||||
f"h.{i}.attn.c_proj.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.mlp.c_fc.weight" in weights:
|
||||
weights[f"h.{i}.mlp.c_fc.weight"] = weights[
|
||||
f"h.{i}.mlp.c_fc.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.mlp.c_proj.weight" in weights:
|
||||
weights[f"h.{i}.mlp.c_proj.weight"] = weights[
|
||||
f"h.{i}.mlp.c_proj.weight"
|
||||
].transpose(1, 0)
|
||||
for weight in weights:
|
||||
if not weight.startswith("model."):
|
||||
new_weights[f"model.{weight}"] = weights[weight]
|
||||
else:
|
||||
new_weights[weight] = weights[weight]
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
@@ -1,189 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_embd: int
|
||||
n_layer: int
|
||||
n_inner: int
|
||||
n_head: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
multi_query: bool = True
|
||||
attention_bias: bool = True
|
||||
mlp_bias: bool = True
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = 1 if self.multi_query else self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim = args.n_embd
|
||||
self.n_heads = n_heads = args.n_head
|
||||
self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
|
||||
|
||||
self.head_dim = head_dim = dim // n_heads
|
||||
|
||||
self.kv_dim = n_kv_heads * head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
|
||||
self.c_proj = nn.Linear(dim, dim, bias=attention_bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [self.dim, self.dim + self.kv_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.n_embd
|
||||
hidden_dim = args.n_inner
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
self.wpe = nn.Embedding(args.n_positions, args.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
B, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
if mask is None and hidden_states.shape[1] > 1:
    mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
position_ids = mx.array(np.arange(L))
|
||||
else:
|
||||
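# Continue absolute position ids from the KV cache offset during incremental decoding.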
position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L))
|
||||
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = GPTBigCodeModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,219 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
# Based on the transformers implementation at:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
max_position_embeddings: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
layer_norm_eps: float
|
||||
vocab_size: int
|
||||
rotary_emb_base: int
|
||||
rotary_pct: float
|
||||
num_key_value_heads: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
args.hidden_size % args.num_attention_heads == 0
|
||||
), "hidden_size must be divisible by num_attention_heads"
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
dims=int(self.head_dim * args.rotary_pct),
|
||||
traditional=False,
|
||||
base=args.rotary_emb_base,
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = nn.Linear(
|
||||
self.hidden_size, 3 * self.hidden_size, bias=True
|
||||
)
|
||||
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.query_key_value(x)
|
||||
|
||||
new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
|
||||
qkv = qkv.reshape(*new_qkv_shape)
|
||||
|
||||
queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
|
||||
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
# gelu_approx corresponds to FastGELUActivation in transformers.
|
||||
return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.layer_norm_eps = args.layer_norm_eps
|
||||
self.attention = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
self.hidden_size,
|
||||
eps=self.layer_norm_eps,
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
self.hidden_size, eps=self.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
# NeoX runs attention and feedforward network in parallel.
|
||||
attn = self.attention(self.input_layernorm(x), mask, cache)
|
||||
ffn = self.mlp(self.post_attention_layernorm(x))
|
||||
out = attn + ffn + residual
|
||||
return out
|
||||
|
||||
|
||||
class GPTNeoXModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
self.layer_norm_eps = args.layer_norm_eps
|
||||
assert self.vocab_size > 0
|
||||
self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
|
||||
self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
|
||||
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
_, L = inputs.shape
|
||||
|
||||
hidden_states = self.embed_in(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(hidden_states, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
out = self.final_layer_norm(hidden_states)
|
||||
out = self.embed_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GPTNeoXModel(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
|
||||
for w_key, w_value in weights.items():
|
||||
# Created through register_buffer in Pytorch, not needed here.
|
||||
ignore_suffixes = [
|
||||
".attention.bias",
|
||||
".attention.masked_bias",
|
||||
".attention.rotary_emb.inv_freq",
|
||||
]
|
||||
|
||||
skip_weight = False
|
||||
for ignored_suffix in ignore_suffixes:
|
||||
if w_key.endswith(ignored_suffix):
|
||||
skip_weight = True
|
||||
break
|
||||
|
||||
if skip_weight:
|
||||
continue
|
||||
|
||||
if not w_key.startswith("model."):
|
||||
w_key = f"model.{w_key}"
|
||||
|
||||
w_key = w_key.replace(".gpt_neox.layers.", ".h.")
|
||||
w_key = w_key.replace(".gpt_neox.", ".")
|
||||
|
||||
new_weights[w_key] = w_value
|
||||
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
@@ -1,195 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .rope_utils import initialize_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
logits_scaling: float
|
||||
attention_multiplier: float
|
||||
embedding_multiplier: float
|
||||
residual_multiplier: float
|
||||
max_position_embeddings: int
|
||||
num_key_value_heads: int
|
||||
attention_bias: bool
|
||||
mlp_bias: bool
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
|
||||
self.scale = args.attention_multiplier
|
||||
attention_bias = args.attention_bias
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.rope = initialize_rope(
|
||||
self.head_dim,
|
||||
args.rope_theta,
|
||||
False,
|
||||
args.rope_scaling,
|
||||
args.max_position_embeddings,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.residual_multiplier = args.residual_multiplier
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r * self.residual_multiplier
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r * self.residual_multiplier
|
||||
return out
|
||||
|
||||
|
||||
class GraniteModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.embedding_multiplier = args.embedding_multiplier
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
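# Granite scales the token embeddings by a config-defined multiplier (residuals and logits get their own multipliers elsewhere in this file).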
h = self.embed_tokens(inputs) * self.embedding_multiplier
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GraniteModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.logits_scaling = args.logits_scaling
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out / self.logits_scaling
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,185 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
attention_bias: bool
|
||||
head_dim: int
|
||||
max_position_embeddings: int
|
||||
mlp_bias: bool
|
||||
model_type: str
|
||||
rope_theta: float
|
||||
tie_word_embeddings: bool
|
||||
|
||||
|
||||
class HeliumAttention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class HeliumMLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.intermediate_size = args.intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||
)
|
||||
self.up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size, bias=args.mlp_bias
|
||||
)
|
||||
self.down_proj = nn.Linear(
|
||||
self.intermediate_size, self.hidden_size, bias=args.mlp_bias
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class HeliumDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = HeliumAttention(args)
|
||||
self.mlp = HeliumMLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class HeliumModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
self.vocab_size = args.vocab_size
|
||||
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
|
||||
self.model = HeliumModel(args)
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,321 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
attention_bias: bool
|
||||
moe_topk: int
|
||||
num_experts: int
|
||||
num_shared_expert: int
|
||||
use_mixed_mlp_moe: bool
|
||||
use_qk_norm: bool
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
use_cla: bool
|
||||
cla_share_factor: int = 2
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
|
||||
class DynamicNTKAlphaRoPE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
base: float = 10000,
|
||||
scaling_alpha: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
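# Dynamic NTK scaling: enlarge the RoPE base by alpha ** (dims / (dims - 2)) so longer contexts fall within the trained frequency range.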
base = base * scaling_alpha ** (dims / (dims - 2))
|
||||
self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=False,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, kv_proj: bool, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
|
||||
if kv_proj:
|
||||
self.k_proj = nn.Linear(
|
||||
dim, n_kv_heads * head_dim, bias=args.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
dim, n_kv_heads * head_dim, bias=args.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
||||
self.use_qk_norm = args.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
|
||||
self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
|
||||
|
||||
self.rope = DynamicNTKAlphaRoPE(
|
||||
head_dim,
|
||||
base=args.rope_theta,
|
||||
scaling_alpha=args.rope_scaling["alpha"],
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
kv_states=None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries = self.q_proj(x)
|
||||
if kv_states is None:
|
||||
keys, values = self.k_proj(x), self.v_proj(x)
|
||||
kv_states = keys, values
|
||||
else:
|
||||
keys, values = kv_states
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
offset = cache.offset if cache else 0
|
||||
queries = self.rope(queries, offset=offset)
|
||||
keys = self.rope(keys, offset=offset)
|
||||
if self.use_qk_norm:
|
||||
queries = self.query_layernorm(queries)
|
||||
keys = self.key_layernorm(keys)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), kv_states
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
def __init__(self, dim, num_experts):
|
||||
super().__init__()
|
||||
self.wg = nn.Linear(dim, num_experts, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.wg(x)
|
||||
|
||||
|
||||
class MoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
intermediate_size = args.intermediate_size
|
||||
self.use_shared_mlp = args.use_mixed_mlp_moe
|
||||
|
||||
if args.use_mixed_mlp_moe:
|
||||
self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert)
|
||||
|
||||
self.num_experts = num_experts = args.num_experts
|
||||
self.top_k = args.moe_topk
|
||||
|
||||
self.gate = Gate(dim, num_experts)
|
||||
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
):
|
||||
gates = self.gate(x)
|
||||
gates = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
k = self.top_k
|
||||
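# Select the top-k experts per token; argpartition avoids a full sort over the expert scores.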
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
if self.use_shared_mlp:
|
||||
shared_expert_output = self.shared_mlp(x)
|
||||
y = y + shared_expert_output
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs, kv_proj: bool):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(kv_proj, args)
|
||||
if args.num_experts == 1:
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
else:
|
||||
self.mlp = MoeBlock(args)
|
||||
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
):
|
||||
r, shared_kv_states = self.self_attn(
|
||||
self.input_layernorm(x), mask, cache, shared_kv_states
|
||||
)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, shared_kv_states
|
||||
|
||||
|
||||
class HunYuanModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
DecoderLayer(
|
||||
args=args,
|
||||
kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
|
||||
)
|
||||
for i in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for i, (layer, c) in enumerate(zip(self.layers, cache)):
|
||||
if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
|
||||
shared_kv_states = None
|
||||
h, shared_kv_states = layer(h, mask, c, shared_kv_states)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = HunYuanModel(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.model.embed_tokens.as_linear(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
|
||||
if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
|
||||
new_weights = {}
|
||||
D = self.args.hidden_size
|
||||
n_kv_heads = self.args.num_key_value_heads
|
||||
n_kv_groups = self.args.num_attention_heads // n_kv_heads
|
||||
head_dim = D // self.args.num_attention_heads
|
||||
for k, v in weights.items():
|
||||
if "qkv_proj" in k:
|
||||
v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
|
||||
splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
|
||||
for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
|
||||
k_new = k.replace("qkv_proj", k_up)
|
||||
new_weights[k_new] = mx.flatten(v_new, 0, 2)
|
||||
elif "gate_and_up_proj" in k:
|
||||
splits = v.split(2, axis=0)
|
||||
for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
|
||||
k_new = k.replace("gate_and_up_proj", k_up)
|
||||
new_weights[k_new] = v_new
|
||||
else:
|
||||
new_weights[k] = v
|
||||
weights = new_weights
|
||||
|
||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n in ["up_proj", "down_proj", "gate_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = True
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
|
||||
)
|
||||
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 2.0
|
||||
)
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv_states = self.wqkv(x)
|
||||
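# The fused wqkv output packs, per KV head, n_kv_groups query heads followed by one key and one value head.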
qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)
|
||||
|
||||
queries = qkv_states[..., : self.n_kv_groups, :]
|
||||
queries = queries.reshape(B, L, -1, self.head_dim)
|
||||
keys = qkv_states[..., -2, :]
|
||||
values = qkv_states[..., -1, :]
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class InternLM2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = InternLM2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.tok_embeddings.as_linear(out)
|
||||
else:
|
||||
out = self.output(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = False
|
||||
qkv_bias: bool = False
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "rope_type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
qkv_bias = args.qkv_bias
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None
|
||||
and args.rope_scaling["rope_type"] == "linear"
|
||||
else 2.0
|
||||
)
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, bias):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class InternLM2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = InternLM2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,208 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads

        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = initialize_rope(
            self.head_dim,
            args.rope_theta,
            args.rope_traditional,
            args.rope_scaling,
            args.max_position_embeddings,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        if hasattr(args, "mlp_bias"):
            mlp_bias = args.mlp_bias
        else:
            mlp_bias = False

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        weights = {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }
        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)
        return weights

    @property
    def layers(self):
        return self.model.layers
@@ -1,242 +0,0 @@
# Copyright © 2024-2025 Apple Inc.

import math
from dataclasses import dataclass

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .cache import MambaCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    state_size: int
    num_hidden_layers: int
    conv_kernel: int
    use_bias: bool
    use_conv_bias: bool
    time_step_rank: int
    tie_word_embeddings: bool = True
    use_bcdt_rms: bool = False
    mixer_rms_eps: float = 1e-6

    def __post_init__(self):
        if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
            self.hidden_size = self.d_model
        if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
            self.intermediate_size = self.d_inner
        if not hasattr(self, "state_size") and hasattr(self, "d_state"):
            self.state_size = self.d_state
        if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
            self.num_hidden_layers = self.n_layer
        if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
            self.num_hidden_layers = self.n_layers
        if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
            self.conv_kernel = self.d_conv
        if not hasattr(self, "use_bias") and hasattr(self, "bias"):
            self.use_bias = self.bias
        if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
            self.use_conv_bias = self.conv_bias

        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)
        if self.model_type == "falcon_mamba":
            self.use_bcdt_rms = True


class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal((self.channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        groups, K, _ = self.weight.shape

        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=groups)

        if self.bias is not None:
            y = y + self.bias

        return y, x[:, -K + 1 :, :]


class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.ssm_state_size = args.state_size
        self.conv_kernel_size = args.conv_kernel
        self.intermediate_size = args.intermediate_size
        self.time_step_rank = int(args.time_step_rank)
        self.use_conv_bias = args.use_conv_bias
        self.use_bcdt_rms = args.use_bcdt_rms
        if self.use_bcdt_rms:
            self.mixer_norm = lambda x: mx.fast.rms_norm(
                x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
            )

        self.in_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
        )

        self.conv1d = DepthWiseConv1d(
            channels=self.intermediate_size,
            kernel_size=self.conv_kernel_size,
            bias=self.use_conv_bias,
            padding=self.conv_kernel_size - 1,
        )

        self.x_proj = nn.Linear(
            self.intermediate_size,
            self.time_step_rank + 2 * self.ssm_state_size,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        A = mx.repeat(
            mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
            repeats=self.intermediate_size,
            axis=0,
        )
        self.A_log = mx.log(A)
        self.D = mx.ones([self.intermediate_size])

        self.out_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=args.use_bias
        )

    def ssm_step(self, x, A, state=None):
        D = self.D
        deltaBC = self.x_proj(x)
        delta, B, C = map(
            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
            mx.split(
                deltaBC,
                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
                axis=-1,
            ),
        )
        if self.use_bcdt_rms:
            delta, B, C = map(self.mixer_norm, (delta, B, C))
        delta = nn.softplus(self.dt_proj(delta))
        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
        if state is not None:
            new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
        y = y + D * x
        return y, new_state

    def _process_sequence(self, x, conv_cache, state_cache):
        B, T, D = x.shape
        xz = self.in_proj(x)
        x, z = xz.split(indices_or_sections=2, axis=-1)

        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
        x = nn.silu(conv_out)

        A = -mx.exp(self.A_log)

        outputs = []
        current_state = state_cache
        y = []
        for t in range(T):
            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
            y.append(y_t)
        y = mx.stack(y, axis=1)
        z = self.out_proj(nn.silu(z) * y)
        return z, (new_conv_cache, current_state)

    def __call__(self, x, cache):
        if cache is None:
            conv_cache, state_cache = None, None
        else:
            conv_cache, state_cache = cache[0], cache[1]

        output, (new_conv_cache, new_state_cache) = self._process_sequence(
            x, conv_cache, state_cache
        )

        if isinstance(cache, MambaCache):
            cache[0] = new_conv_cache
            cache[1] = new_state_cache

        return output


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        return self.mixer(self.norm(x), cache) + x


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            x = layer(x, c)
        return self.norm_f(x)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        B, T = inputs.shape

        x = self.backbone(inputs, cache)

        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(x)
        else:
            logits = self.lm_head(x)

        return logits

    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
        return weights

    def make_cache(self):
        return [MambaCache() for _ in range(len(self.layers))]

    @property
    def layers(self):
        return self.backbone.layers
@@ -1,210 +0,0 @@
# Copyright © 2023-2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    dim_model_base: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    scale_depth: float
    scale_emb: float
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[str, float]]] = None
    tie_word_embeddings: bool = False


class MLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.num_heads = n_heads = args.num_attention_heads
        self.rope_theta = args.rope_theta

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.num_key_value_heads = args.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )

        self.rope = nn.RoPE(
            dims=self.head_dim,
            traditional=args.rope_traditional,
            base=self.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ):
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        attn_output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.o_proj(attn_output)


class DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.hidden_size = args.hidden_size
        self.num_hidden_layers = args.num_hidden_layers

        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        self.scale_depth = args.scale_depth
        self.num_hidden_layers = args.num_hidden_layers

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        return out


class MiniCPMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs) * self.args.scale_emb

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MiniCPMModel(args)

        if not self.args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, mask, cache)

        if not self.args.tie_word_embeddings:
            out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
        else:
            out = out @ self.model.embed_tokens.weight.T

        return out

    def sanitize(self, weights):
        if "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return weights

    @property
    def layers(self):
        return self.model.layers
@@ -1,220 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_experts_per_tok: int = 2
    num_key_value_heads: int = 8
    num_local_experts: int = 8
    rms_norm_eps: float = 1e-5
    rope_theta: float = 1e6
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class MixtralAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.num_heads = args.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.rope_theta = args.rope_theta

        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rope = nn.RoPE(
            self.head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MixtralSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_dim = args.hidden_size
        self.ffn_dim = args.intermediate_size
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x)

        k = self.num_experts_per_tok
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(gates, inds, axis=-1)
        scores = mx.softmax(scores, axis=-1, precise=True)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)

        return y


class MixtralDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.hidden_size = args.hidden_size

        self.self_attn = MixtralAttention(args)

        self.block_sparse_moe = MixtralSparseMoeBlock(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.block_sparse_moe(self.post_attention_layernorm(h))
        out = h + r
        return out


class MixtralModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = MixtralModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, mask, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
            return weights
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(
                                f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
                            )
                            for e in range(self.args.num_local_experts)
                        ]
                        weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
                            mx.stack(to_join)
                        )
        return weights

    @property
    def layers(self):
        return self.model.layers
@@ -1,220 +0,0 @@
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    hidden_act: str
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    partial_rotary_factor: float = 0.5
    rope_theta: float = 10000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        if self.rope_scaling:
            if not "factor" in self.rope_scaling:
                raise ValueError(f"rope_scaling must contain 'factor'")
            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
                "rope_type"
            )
            if rope_type is None:
                raise ValueError(
                    f"rope_scaling must contain either 'type' or 'rope_type'"
                )
            if rope_type not in ["linear"]:
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


@partial(mx.compile, shapeless=True)
def relu_squared(x):
    return nn.relu(x).square()


class NemotronLayerNorm1P(nn.LayerNorm):
    def __call__(self, x):
        weight = self.weight + 1 if "weight" in self else None
        bias = self.bias if "bias" in self else None
        return mx.fast.layer_norm(x, weight, bias, self.eps)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
        self.partial_rotary_factor = args.partial_rotary_factor

        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] == "linear":
            assert isinstance(args.rope_scaling["factor"], float)
            rope_scale = 1 / args.rope_scaling["factor"]
        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        mlp_bias = args.mlp_bias

        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(relu_squared(self.up_proj(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
        self.post_attention_layernorm = NemotronLayerNorm1P(
            args.hidden_size, eps=args.norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class NemotronModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = NemotronModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers
@@ -1,180 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

import sys
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask

try:
    import hf_olmo
except ImportError:
    print("To run olmo install ai2-olmo: pip install ai2-olmo")
    sys.exit(1)


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    d_model: int
    n_layers: int
    mlp_hidden_size: int
    n_heads: int
    vocab_size: int
    embedding_size: int
    rope_theta: float = 10000
    rope_traditional: bool = False
    mlp_ratio: int = 4
    weight_tying: bool = False

    def __post_init__(self):
        self.mlp_hidden_size = (
            self.mlp_hidden_size
            if self.mlp_hidden_size is not None
            else self.mlp_ratio * self.d_model
        )


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        dim = args.d_model

        self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
        self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)

        self.att_norm = nn.LayerNorm(dim, affine=False)
        self.ff_norm = nn.LayerNorm(dim, affine=False)

        head_dim = dim // self.n_heads
        self.scale = head_dim**-0.5

        self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

        self.args = args

    def attend(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.attn_out(output)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.attend(self.att_norm(x), mask, cache)
        h = x + r

        x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)

        out = h + self.ff_out(nn.silu(x2) * x1)
        return out


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_layers = args.n_layers
        self.weight_tying = args.weight_tying

        self.wte = nn.Embedding(args.embedding_size, args.d_model)
        self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        if not self.weight_tying:
            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
        self.norm = nn.LayerNorm(args.d_model, affine=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        h = self.wte(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.blocks)

        for block, c in zip(self.blocks, cache):
            h = block(h, mask, c)

        h = self.norm(h)

        if self.weight_tying:
            return self.wte.as_linear(h), cache

        return self.ff_out(h)


class OlmoModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.transformer = Transformer(args)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        return self.transformer(inputs, mask, cache)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = OlmoModel(args)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        return self.model(inputs, mask, cache)

    @property
    def layers(self):
        return self.model.transformer.blocks
@@ -1,212 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads

        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = initialize_rope(
            self.head_dim,
            args.rope_theta,
            args.rope_traditional,
            args.rope_scaling,
            args.max_position_embeddings,
        )

        self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        if hasattr(args, "mlp_bias"):
            mlp_bias = args.mlp_bias
        else:
            mlp_bias = False

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.post_feedforward_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
        h = x + r
        r = self.post_feedforward_layernorm(self.mlp(h))
        out = h + r
        return out


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        h = self.embed_tokens(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        out = self.model(inputs, cache, mask)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
@@ -1,217 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_experts: int
    num_experts_per_tok: int
    norm_topk_prob: bool = False
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads

        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)

        self.rope = initialize_rope(
            self.head_dim,
            args.rope_theta,
            args.rope_traditional,
            args.rope_scaling,
            args.max_position_embeddings,
        )

        self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)
        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class OlmoeSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_experts = args.num_experts
        self.top_k = args.num_experts_per_tok
        self.norm_topk_prob = args.norm_topk_prob

        self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
        self.switch_mlp = SwitchGLU(
            args.hidden_size,
            args.intermediate_size,
            self.num_experts,
            bias=args.mlp_bias,
        )

    def __call__(self, x: mx.array) -> mx.array:
        B, L, D = x.shape
        x_flat = x.reshape(-1, D)
        router_logits = self.gate(x_flat)
        routing_weights = mx.softmax(router_logits, axis=1, precise=True)
        k = self.top_k
        indices = mx.stop_gradient(
            mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
        )
        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
        if self.norm_topk_prob:
            scores = scores / scores.sum(axis=-1, keepdims=True)
        y = self.switch_mlp(x_flat, indices)
        y = (y * scores[..., None]).sum(axis=-2)
        return y.reshape(B, L, D)


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.self_attn = Attention(args)
        self.mlp = OlmoeSparseMoeBlock(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        x = x + self.self_attn(self.input_layernorm(x), mask, cache)
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


class OlmoeModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        h = self.embed_tokens(inputs)
        if mask is None:
            mask = create_attention_mask(h, cache)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)
        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = OlmoeModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        out = self.model(inputs, cache, mask)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
            return weights
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n in ["up_proj", "down_proj", "gate_proj"]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
                            for e in range(self.args.num_experts)
                        ]
                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
        return weights

    @property
    def layers(self):
        return self.model.layers
@@ -1,223 +0,0 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    head_dim: int
    num_transformer_layers: int
    model_dim: int
    vocab_size: int
    ffn_dim_divisor: int
    num_query_heads: List
    num_kv_heads: List
    ffn_multipliers: List
    ffn_with_glu: bool = True
    normalize_qk_projections: bool = True
    share_input_output_layers: bool = True
    rms_norm_eps: float = 1e-6
    rope_freq_constant: float = 10000


def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by the divisor
    It can be seen at:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
    Args:
        v: input value
        divisor: default to 8
        min_value: minimum divisor value
    Returns:
        new_v: new divisible value
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.head_dim = head_dim = args.head_dim
        self.layer_id = layer_id
        self.model_dim = model_dim = args.model_dim

        self.n_heads = n_heads = args.num_query_heads[layer_id]
        self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id]
        self.scale = head_dim**-0.5

        op_size = (n_heads + (n_kv_heads * 2)) * head_dim
        self.qkv_proj = nn.Linear(model_dim, op_size, bias=False)
        self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)

        self.normalize_qk_projections = args.normalize_qk_projections

        if self.normalize_qk_projections:
            self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
            self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)

        self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.qkv_proj(x)

        qkv = qkv.reshape(
            B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
        ).transpose(0, 2, 1, 3)

        queries, keys, values = mx.split(
            qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
        )

        # Prepare the queries, keys and values for the attention computation
        if self.normalize_qk_projections:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.args = args
        dim = args.model_dim
        ffn_multiplier = args.ffn_multipliers[layer_id]

        intermediate_dim = int(
            make_divisible(
                ffn_multiplier * args.model_dim,
                divisor=args.ffn_dim_divisor,
            )
        )

        self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False)
        self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        x = self.proj_1(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.proj_2(nn.silu(gate) * x)


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        dim = args.model_dim
        self.attn = Attention(args, layer_id=layer_id)
        self.ffn = MLP(args, layer_id=layer_id)
|
||||
self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
|
||||
self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.attn_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.ffn(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class OpenELMModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_transformer_layers = args.num_transformer_layers
|
||||
assert self.vocab_size > 0
|
||||
self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
|
||||
self.layers = [
|
||||
TransformerBlock(args, layer_id=layer_id)
|
||||
for layer_id in range(self.num_transformer_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.token_embeddings(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = OpenELMModel(args)
|
||||
if not args.share_input_output_layers:
|
||||
self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, mask, cache)
|
||||
if self.args.share_input_output_layers:
|
||||
out = self.transformer.token_embeddings.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.layers
|
||||
@@ -1,179 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "phi"
|
||||
max_position_embeddings: int = 2048
|
||||
vocab_size: int = 51200
|
||||
hidden_size: int = 2560
|
||||
num_attention_heads: int = 32
|
||||
num_hidden_layers: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
partial_rotary_factor: float = 0.4
|
||||
intermediate_size: int = 10240
|
||||
layer_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class PhiAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.repeats = self.num_heads // self.num_key_value_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.dense = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=True
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
B, L, D = queries.shape
|
||||
n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(
|
||||
B,
|
||||
L,
|
||||
n_heads,
|
||||
-1,
|
||||
).moveaxis(1, 2)
|
||||
keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
|
||||
values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = scaled_dot_product_attention(
|
||||
queries.astype(mx.float32),
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
).astype(values.dtype)
|
||||
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class PhiDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = PhiAttention(config=config)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = PhiMLP(config)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class PhiModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, c)
|
||||
return self.final_layernorm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = PhiModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,215 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
|
||||
partial_rotary_factor: float = 1.0
|
||||
max_position_embeddings: int = 131072
|
||||
original_max_position_embeddings: int = 4096
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"long_factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
|
||||
print(
|
||||
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
|
||||
)
|
||||
self.rope_scaling = None
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
|
||||
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_dim = int(head_dim * args.partial_rotary_factor)
|
||||
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
||||
self.rope = SuScaledRotaryEmbedding(
|
||||
rope_dim,
|
||||
base=args.rope_theta,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||
short_factor=args.rope_scaling["short_factor"],
|
||||
long_factor=args.rope_scaling["long_factor"],
|
||||
)
|
||||
else:
|
||||
rope_scale = 1.0
|
||||
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
||||
assert isinstance(args.rope_scaling["factor"], float)
|
||||
rope_scale = 1 / args.rope_scaling["factor"]
|
||||
self.rope = nn.RoPE(
|
||||
rope_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.qkv_proj(x)
|
||||
query_pos = self.n_heads * self.head_dim
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.gate_up_proj(x)
|
||||
gate, x = mx.split(x, 2, axis=-1)
|
||||
return self.down_proj(nn.silu(gate) * x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,313 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
dense_attention_every_n_layers: int
|
||||
ff_intermediate_size: int
|
||||
gegelu_limit: float
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
mup_attn_multiplier: float = 1.0
|
||||
mup_use_scaling: bool = True
|
||||
mup_embedding_multiplier: float = 10.0
|
||||
mup_width_multiplier: float = 8.0
|
||||
rope_embedding_base: float = 1000000
|
||||
rope_position_scale: float = 1.0
|
||||
blocksparse_block_size: int = 64
|
||||
blocksparse_num_local_blocks: int = 16
|
||||
blocksparse_vert_stride: int = 8
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def gegelu_impl(a_gelu, a_linear, limit):
|
||||
a_gelu = mx.where(
|
||||
mx.isinf(a_gelu),
|
||||
a_gelu,
|
||||
mx.clip(a_gelu, a_min=None, a_max=limit),
|
||||
)
|
||||
a_linear = mx.where(
|
||||
mx.isinf(a_linear),
|
||||
a_linear,
|
||||
mx.clip(a_linear, a_min=-limit, a_max=limit),
|
||||
)
|
||||
out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
|
||||
return out_gelu * (a_linear + 1.0)
|
||||
|
||||
|
||||
def gegelu(x, limit):
|
||||
a_gelu, a_linear = x[..., ::2], x[..., 1::2]
|
||||
return gegelu_impl(a_gelu, a_linear, limit)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_q_per_kv = n_heads // n_kv_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
|
||||
self.query_key_value = nn.Linear(
|
||||
dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
|
||||
)
|
||||
self.dense = nn.Linear(dim, dim)
|
||||
|
||||
if args.mup_use_scaling:
|
||||
norm_factor = head_dim / args.mup_attn_multiplier
|
||||
else:
|
||||
norm_factor = math.sqrt(head_dim)
|
||||
self.scale = 1.0 / norm_factor
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=False,
|
||||
base=args.rope_embedding_base,
|
||||
scale=args.rope_position_scale,
|
||||
)
|
||||
|
||||
if layer_idx % args.dense_attention_every_n_layers == 0:
|
||||
self.block_sparse = True
|
||||
self.blocksparse_block_size = args.blocksparse_block_size
|
||||
if self.blocksparse_block_size not in (32, 64):
|
||||
raise ValueError(
|
||||
f"Unsupported block size {self.blocksparse_block_size}"
|
||||
)
|
||||
self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
|
||||
self.blocksparse_vert_stride = args.blocksparse_vert_stride
|
||||
else:
|
||||
self.block_sparse = False
|
||||
|
||||
def _block_sparse_mask(self, q_len, kv_len):
|
||||
vert_stride = self.blocksparse_vert_stride
|
||||
local_blocks = self.blocksparse_num_local_blocks
|
||||
block_size = self.blocksparse_block_size
|
||||
n_heads = self.n_heads
|
||||
|
||||
kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
q_blocks = (q_len + block_size - 1) // block_size
|
||||
q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
|
||||
k_pos = mx.arange(kv_blocks)[None, None]
|
||||
|
||||
mask_vert_strided = (
|
||||
mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
|
||||
) % vert_stride
|
||||
mask_vert_strided = (mask_vert_strided == 0)[:, None, :]
|
||||
|
||||
block_mask = (q_pos >= k_pos) & (
|
||||
(q_pos - k_pos < local_blocks) | mask_vert_strided
|
||||
)
|
||||
block_mask = block_mask.reshape(
|
||||
self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
|
||||
)
|
||||
dense_mask = mx.repeat(
|
||||
mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
|
||||
)
|
||||
return block_mask, dense_mask[..., -q_len:, :kv_len]
|
||||
|
||||
def _block_sparse_attention(self, queries, keys, values, scale, mask):
|
||||
queries = scale * queries
|
||||
B = queries.shape[0]
|
||||
L = queries.shape[2]
|
||||
queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
|
||||
keys = mx.expand_dims(keys, 2)
|
||||
values = mx.expand_dims(values, 2)
|
||||
|
||||
# TODO get rid of dense mask if we have a fill value
|
||||
block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
|
||||
scores = queries @ mx.swapaxes(keys, -1, -2)
|
||||
# TODO, uncomment when faster
|
||||
# scores = mx.block_masked_mm(
|
||||
# queries,
|
||||
# mx.swapaxes(keys, -1, -2),
|
||||
# mask_out=block_mask,
|
||||
# block_size=self.blocksparse_block_size,
|
||||
# )
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = scores + mx.where(
|
||||
dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
|
||||
)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
output = scores @ values
|
||||
# TODO, uncomment when faster
|
||||
# output = mx.block_masked_mm(
|
||||
# scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
|
||||
# )
|
||||
return mx.reshape(output, (B, self.n_heads, L, -1))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.query_key_value(x)
|
||||
qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
|
||||
queries = qkv[..., :-2, :].flatten(-3, -2)
|
||||
keys = qkv[..., -2, :]
|
||||
values = qkv[..., -1, :]
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
if self.block_sparse:
|
||||
output = self._block_sparse_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.ff_intermediate_size
|
||||
self.gegelu_limit = args.gegelu_limit
|
||||
self.up_proj = nn.Linear(dim, 2 * hidden_dim)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.up_proj(x)
|
||||
return self.down_proj(gegelu(x, self.gegelu_limit))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args, layer_idx)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_epsilon
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layer_norm_epsilon,
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.mup_embedding_multiplier = args.mup_embedding_multiplier
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args, layer_idx=l)
|
||||
for l in range(args.num_hidden_layers)
|
||||
]
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_epsilon
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
if self.mup_embedding_multiplier:
|
||||
h = self.mup_embedding_multiplier * h
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.final_layernorm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
self.args = args
|
||||
self.mup_width_multiplier = args.mup_width_multiplier
|
||||
self._dummy_tokenizer_ids = mx.array(
|
||||
[100256, 100258, 100259, 100260, 100264, 100265]
|
||||
+ list(range(100267, 100352))
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
if self.mup_width_multiplier:
|
||||
out = out / self.mup_width_multiplier
|
||||
out[self._dummy_tokenizer_ids] = -float("inf")
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "phimoe"
|
||||
vocab_size: int = 32064
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 6400
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 131072
|
||||
original_max_position_embeddings: int = 4096
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_scaling: Dict[str, Union[float, List[float]]] = None
|
||||
num_local_experts: int = 16
|
||||
num_experts_per_tok: int = 2
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
|
||||
|
||||
self.rope = SuScaledRotaryEmbedding(
|
||||
head_dim,
|
||||
base=args.rope_theta,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||
short_factor=args.rope_scaling["short_factor"],
|
||||
long_factor=args.rope_scaling["long_factor"],
|
||||
short_mscale=args.rope_scaling["short_mscale"],
|
||||
long_mscale=args.rope_scaling["long_mscale"],
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class PhiMoESparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_dim = args.hidden_size
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.num_experts = args.num_local_experts
|
||||
self.top_k = args.num_experts_per_tok
|
||||
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class PhiMoEDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.block_sparse_moe = PhiMoESparseMoeBlock(args)
|
||||
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
hidden_states = self.input_layernorm(x)
|
||||
hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PhiMoEModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.args = args
|
||||
self.model = PhiMoEModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(
|
||||
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
|
||||
)
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
|
||||
mx.stack(to_join)
|
||||
)
|
||||
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,202 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchMLP
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
model_type: str
|
||||
num_vocab: int = 51200
|
||||
model_dim: int = 2560
|
||||
num_heads: int = 32
|
||||
num_layers: int = 32
|
||||
rotary_dim: int = 32
|
||||
num_experts_per_tok: int = 2
|
||||
num_local_experts: int = 4
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
self.Wqkv = nn.Linear(dims, 3 * dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.Wqkv(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries.astype(mx.float32),
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
).astype(values.dtype)
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class MOE(nn.Module):
|
||||
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
self.switch_mlp = SwitchMLP(
|
||||
self.dim, self.hidden_dim, self.num_experts, bias=True
|
||||
)
|
||||
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.model_dim
|
||||
mlp_dims = dims * 4
|
||||
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||
self.ln = nn.LayerNorm(dims)
|
||||
self.moe = MOE(config, dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.ln(x)
|
||||
attn_h = self.mixer(h, mask, cache)
|
||||
ff_h = self.moe(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embd = Embd(config)
|
||||
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embd(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
x = layer(x, mask, c)
|
||||
return x
|
||||
|
||||
|
||||
class Embd(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.wte(x)
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(config.model_dim)
|
||||
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(self.ln(inputs))
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_layers):
|
||||
prefix = f"transformer.h.{l}"
|
||||
for n in ["fc1", "fc2"]:
|
||||
for k in ["weight", "scales", "biases", "bias"]:
|
||||
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,214 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
n_shared_head: int = 8
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
head_dim = self.hidden_size // config.num_attention_heads
|
||||
|
||||
self.q_num_heads = config.num_attention_heads
|
||||
self.qk_dim = self.v_dim = head_dim
|
||||
self.k_num_heads = self.v_num_heads = int(
|
||||
np.ceil(self.q_num_heads / config.n_shared_head)
|
||||
)
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.q_num_heads * self.qk_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.k_num_heads * self.qk_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.v_num_heads * self.v_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
|
||||
)
|
||||
self.rotary_emb = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
base=config.rope_theta,
|
||||
scale=1.0,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
bsz, q_len, _ = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rotary_emb(queries, offset=cache.offset)
|
||||
keys = self.rotary_emb(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rotary_emb(queries)
|
||||
keys = self.rotary_emb(keys)
|
||||
|
||||
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
|
||||
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self.scale,
|
||||
mask=attention_mask,
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
|
||||
|
||||
|
||||
class PlamoDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Attention(config)
|
||||
self.mlp = MLP(config)
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
# from LlamaDecoder
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states_sa = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states_sa + hidden_states_mlp
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PlamoDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
class PlamoModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = PlamoDecoder(config) # type: ignore
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None for _ in range(len(self.layers.layers))]
|
||||
|
||||
for layer, c in zip(self.layers.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = PlamoModel(args)
|
||||
self.lm_head: nn.Module = nn.Linear(
|
||||
args.hidden_size, args.vocab_size, bias=False
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
@@ -1,608 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.base import BaseModelArgs, create_attention_mask
|
||||
|
||||
from .cache import KVCache, MambaCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "plamo2"
|
||||
hidden_size: int = 4096
|
||||
num_hidden_layers: int = 32
|
||||
rms_norm_eps: float = 1e-6
|
||||
tie_word_embeddings: bool = True
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 4
|
||||
hidden_size_per_head: int = 128
|
||||
max_position_embeddings: int = 2048
|
||||
attention_window_size: int = 2048
|
||||
full_attention_idx: Optional[list[int]] = None
|
||||
mamba_d_state: int = 64
|
||||
mamba_d_conv: int = 4
|
||||
mamba_num_heads: int = 64
|
||||
mamba_step: int = 2
|
||||
mamba_chunk_size: int = 256
|
||||
mamba_enabled: bool = True
|
||||
intermediate_size: int = 13312
|
||||
vocab_size: int = 32000
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
offset: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = mx.zeros(hidden_size)
|
||||
self.variance_epsilon = eps
|
||||
self.offset = offset
|
||||
|
||||
def __call__(self, hidden_states: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(
|
||||
hidden_states, self.weight + self.offset, self.variance_epsilon
|
||||
)
|
||||
|
||||
|
||||
def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.astype(mx.float32)
|
||||
variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
|
||||
hidden_states = hidden_states * mx.rsqrt(variance + eps)
|
||||
hidden_states = hidden_states.astype(input_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_initial_dt_bias(num_heads: int) -> mx.array:
|
||||
dt_min = 0.001
|
||||
dt_max = 0.1
|
||||
dt = mx.exp(
|
||||
mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
)
|
||||
dt = mx.clip(dt, a_min=1e-4, a_max=None)
|
||||
inv_dt = dt + mx.log(-mx.expm1(-dt))
|
||||
return inv_dt
|
||||
|
||||
|
||||
def get_initial_A(num_heads: int) -> mx.array:
|
||||
A = mx.arange(1, num_heads + 1, dtype=mx.float32)
|
||||
return mx.log(A)
|
||||
|
||||
|
||||
# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
|
||||
def selective_state_update_ref(
|
||||
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
"""
|
||||
has_heads = state.ndim > 3
|
||||
if state.ndim == 3:
|
||||
state = mx.expand_dims(state, 1)
|
||||
if x.ndim == 2:
|
||||
x = mx.expand_dims(x, 1)
|
||||
if dt.ndim == 2:
|
||||
dt = mx.expand_dims(dt, 1)
|
||||
if A.ndim == 2:
|
||||
A = mx.expand_dims(A, 0)
|
||||
if B.ndim == 2:
|
||||
B = mx.expand_dims(B, 1)
|
||||
if C.ndim == 2:
|
||||
C = mx.expand_dims(C, 1)
|
||||
if D is not None and D.ndim == 1:
|
||||
D = mx.expand_dims(D, 0)
|
||||
if z is not None and z.ndim == 2:
|
||||
z = mx.expand_dims(z, 1)
|
||||
if dt_bias is not None and dt_bias.ndim == 1:
|
||||
dt_bias = mx.expand_dims(dt_bias, 0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
dt = dt + dt_bias
|
||||
dt = nn.softplus(dt) if dt_softplus else dt
|
||||
dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
|
||||
B = mx.reshape(
|
||||
mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
|
||||
(batch, nheads, dstate),
|
||||
) # (batch, nheads, dstate)
|
||||
C = mx.reshape(
|
||||
mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
|
||||
(batch, nheads, dstate),
|
||||
) # (batch, nheads, dstate)
|
||||
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
|
||||
B, axis=-2
|
||||
) # (batch, nheads, dim, dstate)
|
||||
state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate)
|
||||
out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
|
||||
if D is not None:
|
||||
out += (x * D).astype(out.dtype)
|
||||
out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out, state
|
||||
|
||||
|
||||
def ssd_update_state(
|
||||
ssm_state: mx.array,
|
||||
x: mx.array,
|
||||
dt: mx.array,
|
||||
A: mx.array,
|
||||
B: mx.array,
|
||||
C: mx.array,
|
||||
D: mx.array,
|
||||
z: mx.array,
|
||||
dt_bias: mx.array,
|
||||
dt_softplus: bool,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
assert ssm_state.dtype == mx.float32
|
||||
dtype = x.dtype
|
||||
|
||||
hidden_size_per_head = x.shape[-1]
|
||||
d_state = B.shape[-1]
|
||||
A = mx.broadcast_to(
|
||||
A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
|
||||
).astype(mx.float32)
|
||||
dt = mx.broadcast_to(
|
||||
dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
|
||||
)
|
||||
dt_bias = mx.broadcast_to(
|
||||
dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
|
||||
)
|
||||
D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
|
||||
out, ssm_state = selective_state_update_ref(
|
||||
ssm_state,
|
||||
x.astype(dtype),
|
||||
dt.astype(dtype),
|
||||
A.astype(mx.float32),
|
||||
B.astype(dtype),
|
||||
C.astype(dtype),
|
||||
D.astype(mx.float32),
|
||||
z.astype(dtype),
|
||||
dt_bias.astype(mx.float32),
|
||||
dt_softplus=dt_softplus,
|
||||
)
|
||||
return out[:, None], ssm_state
|
||||
|
||||
|
||||
def ssd_chunk_scan_combined(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    z: mx.array,
    dt_bias: mx.array,
    dt_softplus: bool,
    ssm_state: mx.array,
) -> tuple[mx.array, mx.array]:
    assert ssm_state.dtype == mx.float32
    length = x.shape[1]
    ys = []
    for i in range(length):
        y, ssm_state = ssd_update_state(
            ssm_state,
            x[:, i],
            dt[:, i],
            A,
            B[:, i],
            C[:, i],
            D if D.ndim == 1 else D[:, i],
            z=z[:, i],
            dt_bias=dt_bias,
            dt_softplus=dt_softplus,
        )
        ys.append(y)
    return mx.concatenate(ys, axis=1), ssm_state
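

# Example (illustrative sketch, not part of the original module): driving the
# per-token scan above with toy tensors. The `_demo_ssd_chunk_scan` helper and
# all sizes below are assumptions chosen only so the shapes line up; it relies
# on the module-level `mx` import.
def _demo_ssd_chunk_scan():
    batch, length, nheads, dim, d_state = 1, 4, 2, 8, 16
    x = mx.random.normal((batch, length, nheads, dim))
    dt = mx.random.normal((batch, length, nheads))
    A = -mx.ones((nheads,))  # negative A gives a decaying state
    B = mx.random.normal((batch, length, d_state))
    C = mx.random.normal((batch, length, d_state))
    D = mx.ones((nheads,))
    z = mx.random.normal((batch, length, nheads, dim))
    dt_bias = mx.zeros((nheads,))
    ssm_state = mx.zeros((batch, nheads, dim, d_state), dtype=mx.float32)
    out, ssm_state = ssd_chunk_scan_combined(
        x, dt, A, B, C, D, z, dt_bias, dt_softplus=True, ssm_state=ssm_state
    )
    # One output per token, per head; the SSM state is carried across steps.
    assert out.shape == (batch, length, nheads, dim)
    assert ssm_state.shape == (batch, nheads, dim, d_state)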
|
||||
|
||||
|
||||
def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
    _, seqlen, dim = x.shape
    state_len = conv_state.shape[-2]
    x = mx.concatenate([conv_state, x], axis=-2)
    conv_state = x[:, -state_len:]
    out = mx.conv1d(
        x,
        weight,
        padding=0,
        groups=dim,
    )[:, -seqlen:]
    return nn.silu(out), conv_state
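

# Example (illustrative sketch, not part of the original module): a single
# decode step through the cached depthwise convolution above. The kernel size,
# channel count, and the `_demo_causal_conv1d_update` helper are assumptions
# made purely for illustration.
def _demo_causal_conv1d_update():
    batch, dim, K = 1, 4, 3
    conv_state = mx.zeros((batch, K - 1, dim))  # rolling buffer of the last K-1 inputs
    weight = mx.random.normal((dim, K, 1))      # depthwise kernel: (channels, K, 1)
    x_t = mx.random.normal((batch, 1, dim))     # one new token
    y_t, conv_state = causal_conv1d_update(conv_state, x_t, weight)
    assert y_t.shape == (batch, 1, dim)
    assert conv_state.shape == (batch, K - 1, dim)  # state keeps the most recent K-1 inputs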
|
||||
|
||||
|
||||
class Mamba(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.d_state = config.mamba_d_state
|
||||
self.d_conv = config.mamba_d_conv
|
||||
self.chunk_size = config.mamba_chunk_size
|
||||
self.num_heads = config.mamba_num_heads
|
||||
self.hidden_size_per_head = config.hidden_size_per_head
|
||||
|
||||
self.intermediate_size = self.num_heads * self.hidden_size_per_head
|
||||
|
||||
self.in_proj = nn.Linear(
|
||||
self.hidden_size, 2 * self.intermediate_size, bias=False
|
||||
)
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=self.intermediate_size,
|
||||
out_channels=self.intermediate_size,
|
||||
bias=False,
|
||||
kernel_size=self.d_conv,
|
||||
groups=self.intermediate_size,
|
||||
padding=0,
|
||||
)
|
||||
self.dt_dim = max(64, self.hidden_size // 16)
|
||||
self.bcdt_proj = nn.Linear(
|
||||
self.intermediate_size,
|
||||
self.dt_dim + 2 * self.d_state,
|
||||
bias=False,
|
||||
)
|
||||
self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
|
||||
|
||||
self.dt_bias = get_initial_dt_bias(self.num_heads)
|
||||
self.A_log = get_initial_A(self.num_heads)
|
||||
self.D = mx.ones(self.num_heads, dtype=mx.float32)
|
||||
|
||||
self.dt_norm_weight = mx.ones(self.dt_dim)
|
||||
self.B_norm_weight = mx.ones(self.d_state)
|
||||
self.C_norm_weight = mx.ones(self.d_state)
|
||||
|
||||
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
):
|
||||
bsize, length, _ = hidden_states.shape
|
||||
|
||||
if cache is not None and cache[0] is not None:
|
||||
conv_state = cache[0]
|
||||
ssm_state = cache[1]
|
||||
else:
|
||||
conv_state = mx.zeros(
|
||||
(bsize, self.d_conv - 1, self.intermediate_size),
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
ssm_state = mx.zeros(
|
||||
(bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
|
||||
dtype=mx.float32,
|
||||
)
|
||||
|
||||
zx = self.in_proj(hidden_states)
|
||||
zx = zx.reshape(bsize, length, self.num_heads, -1)
|
||||
# z: (bsize, length, num_heads, hidden_size_per_head)
|
||||
# x: (bsize, length, num_heads, hidden_size_per_head)
|
||||
z, x = mx.split(
|
||||
zx,
|
||||
[
|
||||
self.hidden_size_per_head,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
|
||||
x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
|
||||
BCdt = self.bcdt_proj(x)
|
||||
x = x.reshape(bsize, length, self.num_heads, -1)
|
||||
B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
|
||||
|
||||
A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
|
||||
dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
|
||||
B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps)
|
||||
C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps)
|
||||
|
||||
# (bsize, length, num_heads, 1)
|
||||
dt = self.dt_proj(dt)[..., None]
|
||||
|
||||
out, ssm_state = ssd_chunk_scan_combined(
|
||||
x,
|
||||
dt.reshape(bsize, length, -1),
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=self.D,
|
||||
z=z,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
ssm_state=ssm_state,
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
cache[0] = conv_state
|
||||
cache[1] = ssm_state
|
||||
y = self.out_proj(out.reshape(bsize, length, -1))
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
head_dim = config.hidden_size_per_head
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_num_heads = config.num_attention_heads
|
||||
self.qk_dim = self.v_dim = head_dim
|
||||
self.k_num_heads = self.v_num_heads = config.num_key_value_heads
|
||||
assert self.q_num_heads % self.k_num_heads == 0
|
||||
self.n_group = self.q_num_heads // self.k_num_heads
|
||||
|
||||
self.q_proj_dim = self.q_num_heads * self.qk_dim
|
||||
self.k_proj_dim = self.k_num_heads * self.qk_dim
|
||||
self.v_proj_dim = self.k_num_heads * self.v_dim
|
||||
self.qkv_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
|
||||
self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))
|
||||
|
||||
self.rope = nn.RoPE(self.qk_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
):
|
||||
B, T, _ = hidden_states.shape
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = mx.split(
|
||||
qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
|
||||
)
|
||||
q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
|
||||
|
||||
q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
|
||||
k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
|
||||
|
||||
if cache is not None:
|
||||
q = self.rope(q, offset=cache.offset)
|
||||
k = self.rope(k, offset=cache.offset)
|
||||
k, v = cache.update_and_fetch(k, v)
|
||||
else:
|
||||
q = self.rope(q)
|
||||
k = self.rope(k)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=self.scale,
|
||||
mask=mask,
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(
|
||||
B, T, self.q_num_heads * self.v_dim
|
||||
)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_up_proj = nn.Linear(
|
||||
self.hidden_size, self.intermediate_size * 2, bias=False
|
||||
)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
h = self.gate_up_proj(x)
|
||||
hs = mx.split(h, 2, axis=-1)
|
||||
return self.down_proj(nn.silu(hs[0]) * hs[1])
|
||||
|
||||
|
||||
class PlamoDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, is_mamba: bool) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.is_mamba = is_mamba
|
||||
self.mixer: nn.Module
|
||||
if is_mamba:
|
||||
self.mixer = Mamba(config)
|
||||
else:
|
||||
self.mixer = Attention(config)
|
||||
self.mlp = MLP(config)
|
||||
self.pre_mixer_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0
|
||||
)
|
||||
self.post_mixer_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
|
||||
)
|
||||
self.pre_mlp_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0
|
||||
)
|
||||
self.post_mlp_norm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.pre_mixer_norm(hidden_states)
|
||||
|
||||
hidden_states_sa = self.mixer(
|
||||
hidden_states=hidden_states,
|
||||
mask=mask,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
|
||||
hidden_states = residual + hidden_states_sa
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.pre_mlp_norm(hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
|
||||
# Residual
|
||||
hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
|
||||
return residual + hidden_states_mlp
|
||||
|
||||
|
||||
def is_mamba(config: ModelArgs, i: int) -> bool:
    if not config.mamba_enabled:
        return False
    assert config.mamba_step > 1
    assert i < config.num_hidden_layers

    if config.num_hidden_layers <= (config.mamba_step // 2):
        # use attention in last layer
        return i != config.num_hidden_layers - 1
    return (i % config.mamba_step) != (config.mamba_step // 2)
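

# Example (illustrative sketch, not part of the original module): how the
# interleaving above plays out for a toy configuration. SimpleNamespace stands
# in for ModelArgs, and the settings are invented for illustration only.
def _demo_is_mamba_pattern():
    from types import SimpleNamespace

    cfg = SimpleNamespace(mamba_enabled=True, mamba_step=2, num_hidden_layers=8)
    pattern = [is_mamba(cfg, i) for i in range(cfg.num_hidden_layers)]
    # Every mamba_step-th layer (offset by mamba_step // 2) uses attention, the rest Mamba.
    assert pattern == [True, False, True, False, True, False, True, False]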
|
||||
|
||||
|
||||
class PlamoDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layers = [
|
||||
PlamoDecoderLayer(config, is_mamba=is_mamba(config, i))
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array, cache):
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
x = decoder_layer(
|
||||
x,
|
||||
mask=mask,
|
||||
cache=cache[i],
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class PlamoModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = PlamoDecoder(config) # type: ignore
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache=None,
|
||||
):
|
||||
batch_size, seq_length = inputs.shape
|
||||
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers.layers)
|
||||
|
||||
# decoder layers
|
||||
out = self.layers(
|
||||
h,
|
||||
mask,
|
||||
cache,
|
||||
)
|
||||
|
||||
return self.norm(out)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_type = config.model_type
|
||||
self.model = PlamoModel(config)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head: nn.Module = nn.Linear(
|
||||
config.hidden_size, self.vocab_size, bias=False
|
||||
)
|
||||
|
||||
    def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
        return weights

    def make_cache(self):
        # TODO: use RotatingKVCache if not full_attn
        # full_attn = self.layer_idx in self.config.full_attention_idx
        return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]
|
||||
|
||||
def __call__(
|
||||
self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None
|
||||
) -> mx.array:
|
||||
outputs = self.model(
|
||||
inputs=inputs,
|
||||
mask=None,
|
||||
cache=cache,
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
logits = self.model.embed_tokens.as_linear(outputs)
|
||||
else:
|
||||
logits = self.lm_head(outputs)
|
||||
|
||||
return logits
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
@@ -1,159 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 2048
|
||||
num_attention_heads: int = 16
|
||||
num_hidden_layers: int = 24
|
||||
kv_channels: int = 128
|
||||
max_position_embeddings: int = 8192
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
intermediate_size: int = 11008
|
||||
no_bias: bool = True
|
||||
vocab_size: int = 151936
|
||||
num_key_value_heads = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
|
||||
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
|
||||
|
||||
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
|
||||
|
||||
proj_size = args.kv_channels * self.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
|
||||
|
||||
self.scale = hidden_size_per_attention_head**-0.5
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.c_attn(x)
|
||||
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
B, L, _ = q.shape
|
||||
|
||||
queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rotary_emb(queries, offset=cache.offset)
|
||||
keys = self.rotary_emb(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rotary_emb(queries)
|
||||
keys = self.rotary_emb(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||
)
|
||||
self.c_proj = nn.Linear(
|
||||
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
a1 = self.w1(x)
|
||||
a2 = self.w2(x)
|
||||
return self.c_proj(a1 * nn.silu(a2))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
residual = x
|
||||
x = self.ln_1(x)
|
||||
x = self.attn(x, mask=mask, cache=cache)
|
||||
residual = x + residual
|
||||
x = self.ln_2(residual)
|
||||
x = self.mlp(x)
|
||||
x = x + residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class QwenModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(self, inputs, mask=None, cache=None):
|
||||
x = self.wte(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
x = layer(x, mask, c)
|
||||
|
||||
return self.ln_f(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = QwenModel(config)
|
||||
self.lm_head = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=not config.no_bias
|
||||
)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
@@ -1,201 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 1000000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
    def sanitize(self, weights):
        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,241 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_experts_per_tok: int
|
||||
num_experts: int
|
||||
moe_intermediate_size: int
|
||||
shared_expert_intermediate_size: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 1000000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
intermediate_size = args.moe_intermediate_size
|
||||
shared_expert_intermediate_size = args.shared_expert_intermediate_size
|
||||
|
||||
self.num_experts = num_experts = args.num_experts
|
||||
self.top_k = args.num_experts_per_tok
|
||||
|
||||
self.gate = nn.Linear(dim, num_experts, bias=False)
|
||||
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
|
||||
|
||||
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
|
||||
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
|
||||
|
||||
    def __call__(
        self,
        x: mx.array,
    ):
        gates = self.gate(x)
        gates = mx.softmax(gates, axis=-1, precise=True)

        k = self.top_k
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
        scores = mx.take_along_axis(gates, inds, axis=-1)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)

        shared_expert_output = self.shared_expert(x)
        shared_expert_output = (
            mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
        )

        return y + shared_expert_output
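

# Example (illustrative sketch, not part of the original module): the top-k
# routing math used above, isolated on a toy gate matrix. The numbers and the
# `_demo_topk_routing` helper are made up; only the argpartition /
# take_along_axis pattern mirrors the block's __call__.
def _demo_topk_routing():
    gates = mx.softmax(mx.array([[0.1, 2.0, -1.0, 1.5]]), axis=-1)  # (1, num_experts)
    k = 2
    inds = mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]  # indices of the top-k experts
    scores = mx.take_along_axis(gates, inds, axis=-1)  # their softmax scores, not renormalized
    assert sorted(inds.tolist()[0]) == [1, 3]  # experts 1 and 3 score highest
    assert scores.shape == (1, k)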
|
||||
|
||||
|
||||
class Qwen2MoeDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(args)
|
||||
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2MoeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2MoeModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
    def sanitize(self, weights):
        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
            return weights
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            for n in ["up_proj", "down_proj", "gate_proj"]:
                for k in ["weight", "scales", "biases"]:
                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
                        to_join = [
                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
                            for e in range(self.args.num_experts)
                        ]
                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
        return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
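

# Example (illustrative sketch, not part of the original file): the per-expert
# stacking that sanitize() above performs, shown on two dummy matrices. The
# shapes and the `_demo_expert_stacking` helper are assumptions for illustration.
def _demo_expert_stacking():
    expert_0 = mx.zeros((8, 4))  # e.g. up_proj of expert 0
    expert_1 = mx.ones((8, 4))   # e.g. up_proj of expert 1
    packed = mx.stack([expert_0, expert_1])  # one fused (num_experts, out, in) tensor
    assert packed.shape == (2, 8, 4)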
|
||||
@@ -1,458 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .cache import MambaCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
attention_bias: bool
|
||||
conv1d_width: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
logits_soft_cap: float
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
attention_window_size: int
|
||||
vocab_size: int
|
||||
embeddings_scale_by_sqrt_dim: bool = True
|
||||
block_types: Optional[List[str]] = None
|
||||
_block_types: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# For some reason these have different names in 2B and 9B
|
||||
if self.block_types is None:
|
||||
self.block_types = self._block_types
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
def rnn_scan(x, a, h0):
    assert x.ndim == 3
    assert a.shape == x.shape[-a.ndim :]
    assert a.dtype == x.dtype

    if x.shape[1] == 1:
        # Using scan in sampling mode.
        if h0 is None:
            return x, x[:, 0]

        else:
            y = a * h0[:, None] + x
            return y, y[:, -1]

    else:
        # Using scan in linear mode.
        if h0 is not None:
            h_t = h0
        else:
            B, _, D = x.shape
            h_t = mx.zeros((B, D), dtype=x.dtype)

        y = mx.zeros_like(x)
        for t in range(x.shape[1]):
            h_t = a[:, t] * h_t + x[:, t]
            y[:, t] = h_t

        return y, h_t
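

# Example (illustrative sketch, not part of the original module): the linear
# recurrence computed above, h_t = a_t * h_{t-1} + x_t, unrolled by hand on a
# toy sequence. The values and the `_demo_rnn_scan` helper are made up.
def _demo_rnn_scan():
    x = mx.array([[[1.0], [1.0], [1.0]]])  # (B=1, T=3, D=1)
    a = mx.full(x.shape, 0.5)
    y, h_last = rnn_scan(x, a, h0=None)
    # h1 = 1.0, h2 = 0.5 * 1.0 + 1.0 = 1.5, h3 = 0.5 * 1.5 + 1.0 = 1.75
    assert y[0, :, 0].tolist() == [1.0, 1.5, 1.75]
    assert h_last[0, 0].item() == 1.75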
|
||||
|
||||
|
||||
class Conv1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((channels, kernel_size, 1))
|
||||
self.bias = mx.zeros((channels,))
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
B, L, C = x.shape
|
||||
groups, K, _ = self.weight.shape
|
||||
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=groups)
|
||||
y = y + self.bias
|
||||
|
||||
return y, x[:, -K + 1 :, :]
|
||||
|
||||
|
||||
class RGLRU(nn.Module):
|
||||
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.width // self.num_heads
|
||||
|
||||
self.recurrent_param = mx.zeros((self.width,))
|
||||
|
||||
self.input_gate_weight = mx.zeros(
|
||||
(self.num_heads, self.head_dim, self.head_dim),
|
||||
)
|
||||
self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
|
||||
|
||||
self.recurrent_gate_weight = mx.zeros(
|
||||
(self.num_heads, self.head_dim, self.head_dim),
|
||||
)
|
||||
self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
def apply_block_linear(h, w, b):
|
||||
h = h.reshape((B, L, self.num_heads, self.head_dim))
|
||||
h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
|
||||
return mx.sigmoid(h.flatten(2, 3))
|
||||
|
||||
# Gates for x and a.
|
||||
gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
|
||||
gate_a = apply_block_linear(
|
||||
x, self.recurrent_gate_weight, self.recurrent_gate_bias
|
||||
)
|
||||
|
||||
# Compute the parameter `A` of the recurrence.
|
||||
log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
|
||||
a = mx.exp(log_a)
|
||||
a_square = mx.exp(2 * log_a)
|
||||
|
||||
# Gate the input.
|
||||
gated_x = x * gate_x
|
||||
|
||||
# Apply gamma normalization to the input.
|
||||
multiplier = mx.sqrt(1 - a_square)
|
||||
if cache is None:
|
||||
multiplier[:, 0, :] = 1.0
|
||||
normalized_x = gated_x * multiplier.astype(x.dtype)
|
||||
|
||||
y, last_h = rnn_scan(
|
||||
x=normalized_x,
|
||||
a=a,
|
||||
h0=cache,
|
||||
)
|
||||
|
||||
return y, last_h
|
||||
|
||||
|
||||
class RecurrentBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
lru_width: int = None,
|
||||
conv1d_temporal_width: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.lru_width = lru_width or width
|
||||
self.conv1d_temporal_width = conv1d_temporal_width
|
||||
|
||||
self.linear_y = nn.Linear(width, self.lru_width)
|
||||
self.linear_x = nn.Linear(width, self.lru_width)
|
||||
self.linear_out = nn.Linear(self.lru_width, width)
|
||||
self.conv_1d = Conv1d(
|
||||
channels=self.lru_width,
|
||||
kernel_size=self.conv1d_temporal_width,
|
||||
)
|
||||
self.rg_lru = RGLRU(
|
||||
width=self.lru_width,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
# y branch.
|
||||
y = self.linear_y(x)
|
||||
y = nn.gelu_approx(y)
|
||||
|
||||
# x branch.
|
||||
x = self.linear_x(x)
|
||||
if cache is None:
|
||||
cache = [None, None]
|
||||
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
|
||||
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
|
||||
|
||||
x = x * y
|
||||
x = self.linear_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LocalAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = (width // num_heads) ** (-0.5)
|
||||
|
||||
self.head_dim = self.width // self.num_heads
|
||||
self.q_proj = nn.Linear(self.width, self.width, bias=False)
|
||||
self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.width, self.width, bias=True)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim // 2,
|
||||
traditional=False,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
|
||||
def __init__(self, width: int, expanded_width: int):
|
||||
super().__init__()
|
||||
self.up_proj = nn.Linear(width, expanded_width // 2)
|
||||
self.gate_proj = nn.Linear(width, expanded_width // 2)
|
||||
self.down_proj = nn.Linear(expanded_width // 2, width)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
gate = self.gate_proj(x)
|
||||
x = self.up_proj(x)
|
||||
return self.down_proj(nn.gelu_approx(gate) * x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
mlp_expanded_width: int,
|
||||
num_heads: int,
|
||||
attention_window_size: int,
|
||||
temporal_block_type: str,
|
||||
lru_width: Optional[int] = None,
|
||||
conv1d_temporal_width: int = 4,
|
||||
):
|
||||
"""Initializes the residual block.
|
||||
|
||||
Args:
|
||||
width: The width of the block.
|
||||
mlp_expanded_width: The width of the expansion inside the MLP block.
|
||||
num_heads: The number of heads for the Attention or the RG-LRU.
|
||||
attention_window_size: The window size for the local attention block.
|
||||
temporal_block_type: Either "recurrent" or "attention", specifying the
|
||||
type of recurrent block to use.
|
||||
lru_width: The width of the RG-LRU if different from `width`.
|
||||
conv1d_temporal_width: The width of the temporal convolution.
|
||||
"""
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.mlp_expanded_width = mlp_expanded_width
|
||||
self.num_heads = num_heads
|
||||
self.attention_window_size = attention_window_size
|
||||
self.temporal_block_type = temporal_block_type
|
||||
self.lru_width = lru_width
|
||||
self.conv1d_temporal_width = conv1d_temporal_width
|
||||
|
||||
self.temporal_pre_norm = RMSNorm(width)
|
||||
if self.temporal_block_type == "recurrent":
|
||||
self.temporal_block = RecurrentBlock(
|
||||
width=self.width,
|
||||
num_heads=self.num_heads,
|
||||
lru_width=self.lru_width,
|
||||
conv1d_temporal_width=self.conv1d_temporal_width,
|
||||
)
|
||||
|
||||
else:
|
||||
self.temporal_block = LocalAttentionBlock(
|
||||
width=self.width,
|
||||
num_heads=self.num_heads,
|
||||
window_size=self.attention_window_size,
|
||||
)
|
||||
|
||||
self.channel_pre_norm = RMSNorm(width)
|
||||
self.mlp_block = MLPBlock(
|
||||
width=self.width,
|
||||
expanded_width=self.mlp_expanded_width,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
raw_x = x
|
||||
|
||||
inputs_normalized = self.temporal_pre_norm(raw_x)
|
||||
|
||||
x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
|
||||
residual = x + raw_x
|
||||
|
||||
x = self.channel_pre_norm(residual)
|
||||
x = self.mlp_block(x)
|
||||
|
||||
x = x + residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Griffin(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
|
||||
block_types = config.block_types
|
||||
|
||||
self.layers = [
|
||||
ResidualBlock(
|
||||
width=config.hidden_size,
|
||||
mlp_expanded_width=config.intermediate_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
attention_window_size=config.attention_window_size,
|
||||
temporal_block_type=block_types[i % len(block_types)],
|
||||
lru_width=None,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tokens,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
x = self.embed_tokens(tokens)
|
||||
if self.scale_by_sqrt_dim:
|
||||
x = x * math.sqrt(x.shape[-1])
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
if block.temporal_block_type != "recurrent":
|
||||
mask_cache = [cache[i]]
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, mask_cache)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, mask=mask, cache=cache[i])
|
||||
|
||||
return self.final_norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
self.args = config
|
||||
self.model = Griffin(config)
|
||||
self.model_type = config.model_type
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
tokens: Sequence of input tokens.
|
||||
"""
|
||||
logits = self.model(tokens, mask=mask, cache=cache)
|
||||
if "lm_head" in self:
|
||||
logits = self.lm_head(logits)
|
||||
else:
|
||||
logits = self.model.embed_tokens.as_linear(logits)
|
||||
|
||||
c = self.args.logits_soft_cap
|
||||
if c:
|
||||
logits = mx.tanh(logits / c) * c
|
||||
return logits
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv_1d.weight" in k and v.shape[-1] != 1:
                weights[k] = v.moveaxis(2, 1)
        if "lm_head.weight" not in weights:
            self.pop("lm_head")
        return weights

    def make_cache(self):
        cache = []
        for layer in self.layers:
            if layer.temporal_block_type == "recurrent":
                cache.append(MambaCache())
            else:
                cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
        return cache
|
||||
@@ -1,91 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class Llama3RoPE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scaling_config: dict = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.traditional = traditional
|
||||
|
||||
factor = scaling_config["factor"]
|
||||
low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
|
||||
old_context_len = scaling_config.get(
|
||||
"original_max_position_embeddings",
|
||||
8192,
|
||||
)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
freqs = base ** (mx.arange(0, dims, 2) / dims)
|
||||
wavelens = 2 * mx.pi * freqs
|
||||
|
||||
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
||||
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
||||
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
||||
high_freq_factor - low_freq_factor
|
||||
)
|
||||
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
||||
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"{self.dims}, traditional={self.traditional}, "
|
||||
f"max_position_embeddings={self.max_position_embeddings}"
|
||||
)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
def initialize_rope(
    dims,
    base,
    traditional,
    scaling_config: Optional[dict] = None,
    max_position_embeddings: Optional[int] = None,
):
    if scaling_config is not None:
        rope_type = scaling_config.get("type") or scaling_config.get(
            "rope_type", "default"
        )
    else:
        rope_type = "default"

    if rope_type in ["default", "linear"]:
        scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
        return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)

    elif rope_type == "llama3":
        return Llama3RoPE(
            dims=dims,
            max_position_embeddings=max_position_embeddings,
            traditional=traditional,
            base=base,
            scaling_config=scaling_config,
        )
    else:
        raise ValueError(f"Unsupported RoPE type {rope_type}")
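

# Example (illustrative sketch, not part of the original module): how a
# config's rope_scaling dict selects an implementation above. The scaling
# dicts, dims/base values, and the `_demo_initialize_rope` helper are
# hypothetical, not taken from any real model config.
def _demo_initialize_rope():
    plain = initialize_rope(dims=64, base=10000.0, traditional=False)
    assert isinstance(plain, nn.RoPE)

    linear = initialize_rope(
        dims=64,
        base=10000.0,
        traditional=False,
        scaling_config={"type": "linear", "factor": 4.0},
    )
    assert isinstance(linear, nn.RoPE)  # same class, but scale = 1 / factor

    llama3 = initialize_rope(
        dims=64,
        base=500000.0,
        traditional=False,
        scaling_config={"rope_type": "llama3", "factor": 8.0},
        max_position_embeddings=131072,
    )
    assert isinstance(llama3, Llama3RoPE)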
|
||||
@@ -1,211 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
intermediate_size: int
|
||||
rope_theta: float
|
||||
use_qkv_bias: bool
|
||||
partial_rotary_factor: float
|
||||
layer_norm_eps: float
|
||||
use_parallel_residual: bool = False
|
||||
qk_layernorm: bool = False
|
||||
|
||||
|
||||
class LayerNormPerHead(nn.Module):
|
||||
|
||||
def __init__(self, head_dim, num_heads, eps):
|
||||
super().__init__()
|
||||
self.norms = [
|
||||
nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
|
||||
]
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
w = mx.stack([n.weight for n in self.norms])
|
||||
return w * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
bias=config.use_qkv_bias,
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
bias=config.use_qkv_bias,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
self.qk_layernorm = config.qk_layernorm
|
||||
if self.qk_layernorm:
|
||||
self.q_layernorm = LayerNormPerHead(
|
||||
self.head_dim, self.num_heads, eps=config.layer_norm_eps
|
||||
)
|
||||
self.k_layernorm = LayerNormPerHead(
|
||||
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
B, L, D = queries.shape
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1)
|
||||
if self.qk_layernorm:
|
||||
queries = self.q_layernorm(queries)
|
||||
keys = self.k_layernorm(keys)
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=scale, mask=mask
|
||||
).astype(values.dtype)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config=config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
if not self.use_parallel_residual:
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
r = self.self_attn(h, mask, cache)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
out = x + r + self.mlp(h)
|
||||
else:
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class StableLM(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, cache=c)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = StableLM(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,169 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
norm_epsilon: float = 1e-5
|
||||
vocab_size: int = 49152
|
||||
rope_theta: float = 100000
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
|
||||
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.norm_epsilon
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Starcoder2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Starcoder2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, mask, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
@@ -1,66 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class SuScaledRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
base: float = 10000.0,
|
||||
max_position_embeddings: int = 131072,
|
||||
original_max_position_embeddings: int = 4096,
|
||||
short_factor: Union[List[float], float] = 1.0,
|
||||
long_factor: Union[List[float], float] = 1.0,
|
||||
short_mscale: float = None,
|
||||
long_mscale: float = None,
|
||||
):
|
||||
"""
|
||||
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
|
||||
|
||||
Args:
|
||||
dims (int): The feature dimensions to be rotated.
|
||||
base (int, optional): Base for the exponential scaling.
|
||||
max_position_embeddings (int, optional): The maximum sequence
|
||||
length that this model was trained with. This is used to determine
|
||||
the size of the original RoPE embeddings when using long scaling.
|
||||
Default: ``131072``.
|
||||
original_max_position_embeddings (int, optional): The maximum
|
||||
sequence length that this model was trained with. This is used to
|
||||
determine the size of the original RoPE embeddings when using long
|
||||
scaling. Default: ``4096``.
|
||||
short_factor (float or list[float], optional): List of scaling
|
||||
factors for sequences of length lesser than
|
||||
``original_max_position_embeddings``. Default: ``1.0``.
|
||||
long_factor (float or list[float], optional): List of scaling
|
||||
factors for sequences of length greater than
|
||||
``original_max_position_embeddings``. Default: ``1.0``.
|
||||
short_mscale (float, optional): Scale the input prior to embedding.
|
||||
long_mscale (float, optional): Scale the input prior to embedding.
|
||||
"""
|
||||
super().__init__()
|
||||
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
||||
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.scale = long_mscale or math.sqrt(
|
||||
1
|
||||
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
||||
/ math.log(original_max_position_embeddings)
|
||||
)
|
||||
self.dim = dims
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
x[..., : self.dim] = self.scale * x[..., : self.dim]
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dim,
|
||||
traditional=False,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
@@ -1,167 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class QuantizedSwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
bias: bool = True,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight, self.scales, self.biases = mx.quantize(
|
||||
mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
),
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
# Freeze this model's parameters
|
||||
self.freeze()
|
||||
|
||||
def unfreeze(self, *args, **kwargs):
|
||||
"""Wrap unfreeze so that we unfreeze any layers we might contain but
|
||||
our parameters will remain frozen."""
|
||||
super().unfreeze(*args, **kwargs)
|
||||
self.freeze(recurse=False)
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.scales.shape[2] * self.group_size
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.gather_qmm(
|
||||
x,
|
||||
self["weight"],
|
||||
self["scales"],
|
||||
self["biases"],
|
||||
rhs_indices=indices,
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
|
||||
class SwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.weight.shape[2]
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
num_experts, output_dims, input_dims = self.weight.shape
|
||||
ql = QuantizedSwitchLinear(
|
||||
input_dims, output_dims, num_experts, False, group_size, bits
|
||||
)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
|
||||
if "bias" in self:
|
||||
ql.bias = self.bias
|
||||
return ql
|
||||
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.silu,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x_up = self.up_proj(x, indices)
|
||||
x_gate = self.gate_proj(x, indices)
|
||||
x = self.down_proj(self.activation(x_gate) * x_up, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
|
||||
|
||||
class SwitchMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.gelu_approx,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x = self.fc1(x, indices)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
mlx>=0.22.0
|
||||
numpy
|
||||
transformers[sentencepiece]>=4.39.3
|
||||
protobuf
|
||||
pyyaml
|
||||
jinja2
|
||||
@@ -1,257 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def make_sampler(
|
||||
temp: float = 0.0,
|
||||
top_p: float = 0.0,
|
||||
min_p: float = 0.0,
|
||||
min_tokens_to_keep: int = 1,
|
||||
top_k: int = -1,
|
||||
) -> Callable[mx.array, mx.array]:
|
||||
"""
|
||||
Make a sampler function for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
top_k (int, optional): The top k tokens ranked by probability to constrain
|
||||
the sampling to.
|
||||
|
||||
Returns:
|
||||
Callable[mx.array, mx.array]:
|
||||
A sampler which takes log-probabilities and returns tokens.
|
||||
"""
|
||||
if temp == 0:
|
||||
return lambda x: mx.argmax(x, axis=-1)
|
||||
|
||||
# Create sampler chain
|
||||
sampling_methods = []
|
||||
if top_k > 0:
|
||||
sampling_methods.append(lambda x: apply_top_k(x, top_k))
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
sampling_methods.append(lambda x: apply_top_p(x, top_p))
|
||||
if min_p != 0.0:
|
||||
sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep))
|
||||
|
||||
# Apply the sampling methods
|
||||
def sampler(logits):
|
||||
for method in sampling_methods:
|
||||
logits = method(logits)
|
||||
|
||||
# Return the sampled token
|
||||
return categorical_sampling(logits, temp)
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
def make_logits_processors(
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
):
|
||||
"""
|
||||
Make logits processors for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
|
||||
Returns:
|
||||
List[Callable[[mx.array, mx.array], mx.array]]:
|
||||
A list of logits processors. Each processor in the list is a
|
||||
callable which takes an array of tokens and an array of logits
|
||||
and returns the updated logits.
|
||||
"""
|
||||
logits_processors = []
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processors.append(logit_bias_processor)
|
||||
|
||||
if repetition_penalty and repetition_penalty != 0.0:
|
||||
logits_processors.append(
|
||||
make_repetition_penalty(repetition_penalty, repetition_context_size)
|
||||
)
|
||||
return logits_processors
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_top_k(
|
||||
logprobs: mx.array,
|
||||
top_k: int,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Sample from only the top K tokens ranked by probability.
|
||||
|
||||
Args:
|
||||
logprobs: A vector of log probabilities.
|
||||
top_k (int): Top k tokens to sample from.
|
||||
"""
|
||||
vocab_size = logprobs.shape[-1]
|
||||
if not isinstance(top_k, int) or not (0 < top_k < vocab_size):
|
||||
raise ValueError(
|
||||
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
|
||||
f" but is {top_k}."
|
||||
)
|
||||
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
|
||||
masked_logprobs = mx.put_along_axis(
|
||||
logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
|
||||
)
|
||||
return masked_logprobs
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_min_p(
|
||||
logprobs: mx.array,
|
||||
min_p: float,
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Apply min-p sampling to the logprobs.
|
||||
|
||||
Min-p keeps all tokens that are above a minimum probability, scaled by the
|
||||
probability of the most likely token. As a result, the filter is more
|
||||
aggressive given a very high-probability token.
|
||||
|
||||
Args:
|
||||
logprobs: A vector of log probabilities.
|
||||
min_p (float): Minimum token probability. Typical values are in the
|
||||
0.01-0.2 range, comparably selective as setting `top_p` in the
|
||||
0.99-0.8 range.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered. Default: ``1``.
|
||||
|
||||
"""
|
||||
if not (0 <= min_p <= 1.0):
|
||||
raise ValueError(
|
||||
f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
|
||||
)
|
||||
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
|
||||
raise ValueError(
|
||||
f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
|
||||
)
|
||||
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
|
||||
|
||||
# Indices sorted in decreasing order
|
||||
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
||||
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
||||
|
||||
# Top probability
|
||||
top_logprobs = sorted_logprobs[:, 0:1]
|
||||
|
||||
# Calculate the min_p threshold
|
||||
scaled_min_p = top_logprobs + math.log(min_p)
|
||||
|
||||
# Mask tokens that have a probability less than the scaled min_p
|
||||
tokens_to_remove = sorted_logprobs < scaled_min_p
|
||||
tokens_to_remove[..., :min_tokens_to_keep] = False
|
||||
|
||||
# Create pool of tokens with probability less than scaled min_p
|
||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||
|
||||
# Create a mapping to rearrange back to original indices
|
||||
# Use argsort of sorted_indices to get the inverse permutation
|
||||
inverse_indices = mx.argsort(sorted_indices, axis=-1)
|
||||
|
||||
# Rearrange selected_logprobs back to original order
|
||||
original_order_logprobs = mx.take_along_axis(
|
||||
selected_logprobs, inverse_indices, axis=-1
|
||||
)
|
||||
|
||||
return original_order_logprobs
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
|
||||
"""
|
||||
Apply top-p (nucleus) sampling to logits.
|
||||
|
||||
Args:
|
||||
logits: The logits from the model's output.
|
||||
top_p: The cumulative probability threshold for top-p filtering.
|
||||
Returns:
|
||||
token selected based on the top-p criterion.
|
||||
"""
|
||||
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
||||
probs = mx.softmax(logits, axis=-1)
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=-1)
|
||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
# select tokens with cumulative probs below threshold
|
||||
top_probs = mx.where(
|
||||
cumulative_probs > 1 - top_p,
|
||||
sorted_probs,
|
||||
0,
|
||||
)
|
||||
|
||||
# Create a mapping to rearrange back to original indices
|
||||
# Use argsort of sorted_indices to get the inverse permutation
|
||||
inverse_indices = mx.argsort(sorted_indices, axis=-1)
|
||||
|
||||
# Rearrange top_probs back to original order
|
||||
original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
|
||||
|
||||
# Convert back to logits and return
|
||||
return mx.log(original_order_probs)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def categorical_sampling(logits, temp):
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
|
||||
def make_repetition_penalty(penalty: float, context_size: int = 20):
|
||||
"""
|
||||
Make repetition penalty processor.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
context_size (int): The number of previous tokens to use.
|
||||
Default: ``20``.
|
||||
|
||||
Returns:
|
||||
Callable[[mx.array, List[int]], mx.array]:
|
||||
The repetition penalty processor.
|
||||
"""
|
||||
if penalty < 0 or not isinstance(penalty, (int, float)):
|
||||
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||
|
||||
def repetition_penalty_processor(tokens, logits):
|
||||
if len(tokens) > 0:
|
||||
tokens = tokens[-context_size:]
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0,
|
||||
selected_logits * penalty,
|
||||
selected_logits / penalty,
|
||||
)
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
return repetition_penalty_processor
|
||||
@@ -1,785 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
from ._version import __version__
|
||||
from .models.cache import make_prompt_cache
|
||||
from .sample_utils import make_logits_processors, make_sampler
|
||||
from .utils import load, stream_generate
|
||||
|
||||
|
||||
def get_system_fingerprint():
|
||||
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
|
||||
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
|
||||
|
||||
|
||||
class StopCondition(NamedTuple):
|
||||
stop_met: bool
|
||||
trim_length: int
|
||||
|
||||
|
||||
def stopping_criteria(
|
||||
tokens: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
eos_token_id: Union[int, None],
|
||||
) -> StopCondition:
|
||||
"""
|
||||
Determines whether the token generation should stop based on predefined
|
||||
conditions.
|
||||
|
||||
Args:
|
||||
tokens (List[int]): The current sequence of generated tokens.
|
||||
stop_id_sequences (List[List[[int]]): A list of integer lists, each
|
||||
representing a sequence of token IDs. If the end of the `tokens`
|
||||
list matches any of these sequences, the generation should stop.
|
||||
eos_token_id (Union[int, None]): The token ID that represents the
|
||||
end-of-sequence. If the last token in `tokens` matches this, the
|
||||
generation should stop.
|
||||
|
||||
Returns:
|
||||
StopCondition: A named tuple indicating whether the stop condition has
|
||||
been met (`stop_met`) and how many tokens should be trimmed from the
|
||||
end if it has (`trim_length`).
|
||||
"""
|
||||
if tokens and tokens[-1] == eos_token_id:
|
||||
return StopCondition(stop_met=True, trim_length=0)
|
||||
|
||||
for stop_ids in stop_id_sequences:
|
||||
if len(tokens) >= len(stop_ids):
|
||||
if tokens[-len(stop_ids) :] == stop_ids:
|
||||
return StopCondition(stop_met=True, trim_length=len(stop_ids))
|
||||
|
||||
return StopCondition(stop_met=False, trim_length=0)
|
||||
|
||||
|
||||
def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
|
||||
"""
|
||||
Checks if a suffix of s1 has overlap with a prefix of s2
|
||||
|
||||
Args:
|
||||
s1 (Sequence): The first sequence
|
||||
s2 (Sequence): The second sequence
|
||||
|
||||
Returns:
|
||||
bool: If the two sequences have overlap
|
||||
"""
|
||||
max_overlap = min(len(s1), len(s2))
|
||||
return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
|
||||
|
||||
|
||||
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||
default_role_mapping = {
|
||||
"system_prompt": (
|
||||
"A chat between a curious user and an artificial intelligence "
|
||||
"assistant. The assistant follows the given rules no matter what."
|
||||
),
|
||||
"system": "ASSISTANT's RULE: ",
|
||||
"user": "USER: ",
|
||||
"assistant": "ASSISTANT: ",
|
||||
"stop": "\n",
|
||||
}
|
||||
role_mapping = role_mapping if role_mapping is not None else default_role_mapping
|
||||
|
||||
prompt = ""
|
||||
for line in messages:
|
||||
role_prefix = role_mapping.get(line["role"], "")
|
||||
stop = role_mapping.get("stop", "")
|
||||
content = line.get("content", "")
|
||||
prompt += f"{role_prefix}{content}{stop}"
|
||||
|
||||
prompt += role_mapping.get("assistant", "")
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
def process_message_content(messages):
|
||||
"""
|
||||
Convert message content to a format suitable for `apply_chat_template`.
|
||||
|
||||
The function operates on messages in place. It converts the 'content' field
|
||||
to a string instead of a list of text fragments.
|
||||
|
||||
Args:
|
||||
message_list (list): A list of dictionaries, where each dictionary may
|
||||
have a 'content' key containing a list of dictionaries with 'type' and
|
||||
'text' keys.
|
||||
|
||||
Raises:
|
||||
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
||||
|
||||
"""
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
text_fragments = [
|
||||
fragment["text"] for fragment in content if fragment["type"] == "text"
|
||||
]
|
||||
if len(text_fragments) != len(content):
|
||||
raise ValueError("Only 'text' content type is supported.")
|
||||
message["content"] = "".join(text_fragments)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptCache:
|
||||
cache: List[Any] = field(default_factory=list)
|
||||
model_key: Tuple[str, Optional[str]] = ("", None)
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelProvider:
|
||||
def __init__(self, cli_args: argparse.Namespace):
|
||||
"""Load models on demand and persist them across the whole process."""
|
||||
self.cli_args = cli_args
|
||||
self.model_key = None
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
|
||||
# Preload the default model if it is provided
|
||||
if self.cli_args.model is not None:
|
||||
self.load("default_model")
|
||||
|
||||
def _validate_model_path(self, model_path: str):
|
||||
model_path = Path(model_path)
|
||||
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
|
||||
raise RuntimeError(
|
||||
"Local models must be relative to the current working dir."
|
||||
)
|
||||
|
||||
# Added in adapter_path to load dynamically
|
||||
def load(self, model_path, adapter_path=None):
|
||||
if self.model_key == (model_path, adapter_path):
|
||||
return self.model, self.tokenizer
|
||||
|
||||
# Remove the old model if it exists.
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.model_key = None
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = {
|
||||
"trust_remote_code": True if self.cli_args.trust_remote_code else None
|
||||
}
|
||||
if self.cli_args.chat_template:
|
||||
tokenizer_config["chat_template"] = self.cli_args.chat_template
|
||||
|
||||
if model_path == "default_model" and self.cli_args.model is not None:
|
||||
model, tokenizer = load(
|
||||
self.cli_args.model,
|
||||
adapter_path=(
|
||||
adapter_path if adapter_path else self.cli_args.adapter_path
|
||||
), # if the user doesn't change the model but adds an adapter path
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
else:
|
||||
self._validate_model_path(model_path)
|
||||
model, tokenizer = load(
|
||||
model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
|
||||
)
|
||||
|
||||
if self.cli_args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
self.model_key = (model_path, adapter_path)
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
return self.model, self.tokenizer
|
||||
|
||||
|
||||
class APIHandler(BaseHTTPRequestHandler):
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: ModelProvider,
|
||||
*args,
|
||||
prompt_cache: Optional[PromptCache] = None,
|
||||
system_fingerprint: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create static request specific metadata
|
||||
"""
|
||||
self.created = int(time.time())
|
||||
self.model_provider = model_provider
|
||||
self.prompt_cache = prompt_cache or PromptCache()
|
||||
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _set_cors_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
self.send_header("Access-Control-Allow-Methods", "*")
|
||||
self.send_header("Access-Control-Allow-Headers", "*")
|
||||
|
||||
def _set_completion_headers(self, status_code: int = 200):
|
||||
self.send_response(status_code)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self._set_cors_headers()
|
||||
|
||||
def _set_stream_headers(self, status_code: int = 200):
|
||||
self.send_response(status_code)
|
||||
self.send_header("Content-type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self._set_cors_headers()
|
||||
|
||||
def do_OPTIONS(self):
|
||||
self._set_completion_headers(204)
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self):
|
||||
"""
|
||||
Respond to a POST request from a client.
|
||||
"""
|
||||
endpoints = {
|
||||
"/v1/completions": self.handle_text_completions,
|
||||
"/v1/chat/completions": self.handle_chat_completions,
|
||||
"/chat/completions": self.handle_chat_completions,
|
||||
}
|
||||
|
||||
if self.path not in endpoints:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not Found")
|
||||
return
|
||||
|
||||
# Fetch and parse request body
|
||||
content_length = int(self.headers["Content-Length"])
|
||||
raw_body = self.rfile.read(content_length)
|
||||
self.body = json.loads(raw_body.decode())
|
||||
indent = "\t" # Backslashes can't be inside of f-strings
|
||||
logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}")
|
||||
assert isinstance(
|
||||
self.body, dict
|
||||
), f"Request should be dict, but got {type(self.body)}"
|
||||
|
||||
# Extract request parameters from the body
|
||||
self.stream = self.body.get("stream", False)
|
||||
self.stream_options = self.body.get("stream_options", None)
|
||||
self.requested_model = self.body.get("model", "default_model")
|
||||
self.adapter = self.body.get("adapters", None)
|
||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||
if self.max_tokens is None:
|
||||
self.max_tokens = self.body.get("max_tokens", 512)
|
||||
self.temperature = self.body.get("temperature", 0.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||
self.logit_bias = self.body.get("logit_bias", None)
|
||||
self.logprobs = self.body.get("logprobs", -1)
|
||||
self.validate_model_parameters()
|
||||
|
||||
# Load the model if needed
|
||||
try:
|
||||
self.model, self.tokenizer = self.model_provider.load(
|
||||
self.requested_model, self.adapter
|
||||
)
|
||||
except:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not Found")
|
||||
return
|
||||
|
||||
# Get stop id sequences, if provided
|
||||
stop_words = self.body.get("stop")
|
||||
stop_words = stop_words or []
|
||||
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
|
||||
stop_id_sequences = [
|
||||
self.tokenizer.encode(stop_word, add_special_tokens=False)
|
||||
for stop_word in stop_words
|
||||
]
|
||||
|
||||
# Send header type
|
||||
(
|
||||
self._set_stream_headers(200)
|
||||
if self.stream
|
||||
else self._set_completion_headers(200)
|
||||
)
|
||||
|
||||
# Call endpoint specific method
|
||||
prompt = endpoints[self.path]()
|
||||
self.handle_completion(prompt, stop_id_sequences)
|
||||
|
||||
def validate_model_parameters(self):
|
||||
"""
|
||||
Validate the model parameters passed in the request for the correct types and values.
|
||||
"""
|
||||
if not isinstance(self.stream, bool):
|
||||
raise ValueError("stream must be a boolean")
|
||||
|
||||
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
|
||||
raise ValueError("max_tokens must be a non-negative integer")
|
||||
|
||||
if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
|
||||
raise ValueError("temperature must be a non-negative float")
|
||||
|
||||
if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
|
||||
raise ValueError("top_p must be a float between 0 and 1")
|
||||
|
||||
if (
|
||||
not isinstance(self.repetition_penalty, (float, int))
|
||||
or self.repetition_penalty < 0
|
||||
):
|
||||
raise ValueError("repetition_penalty must be a non-negative float")
|
||||
|
||||
if self.logprobs != -1 and not (0 < self.logprobs <= 10):
|
||||
raise ValueError(
|
||||
f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(self.repetition_context_size, int)
|
||||
or self.repetition_context_size < 0
|
||||
):
|
||||
raise ValueError("repetition_context_size must be a non-negative integer")
|
||||
|
||||
if self.logit_bias is not None:
|
||||
if not isinstance(self.logit_bias, dict):
|
||||
raise ValueError("logit_bias must be a dict of int to float")
|
||||
|
||||
try:
|
||||
self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
|
||||
except ValueError:
|
||||
raise ValueError("logit_bias must be a dict of int to float")
|
||||
|
||||
if not isinstance(self.requested_model, str):
|
||||
raise ValueError("model must be a string")
|
||||
if self.adapter is not None and not isinstance(self.adapter, str):
|
||||
raise ValueError("adapter must be a string")
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
text: str,
|
||||
finish_reason: Union[Literal["length", "stop"], None],
|
||||
prompt_token_count: Optional[int] = None,
|
||||
completion_token_count: Optional[int] = None,
|
||||
token_logprobs: Optional[List[float]] = None,
|
||||
top_tokens: Optional[List[Dict[int, float]]] = None,
|
||||
tokens: Optional[List[int]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a single response packet based on response type (stream or
|
||||
not), completion type and parameters.
|
||||
|
||||
Args:
|
||||
text (str): Text generated by model
|
||||
finish_reason (Union[Literal["length", "stop"], None]): The reason the
|
||||
response is being sent: "length", "stop" or `None`.
|
||||
prompt_token_count (Optional[int]): The number of tokens in the prompt,
|
||||
used to populate the "usage" field (not used when stream).
|
||||
completion_token_count (Optional[int]): The number of tokens in the
|
||||
response, used to populate the "usage" field (not used when stream).
|
||||
token_logprobs (Optional[List[float]]): The log probabilities per token,
|
||||
in token order.
|
||||
top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
|
||||
tokens to logprobs for the top N tokens at each token position.
|
||||
tokens (Optional[List[int]]): List of tokens to return with logprobs structure
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the response, in the same format as
|
||||
OpenAI's API.
|
||||
"""
|
||||
token_logprobs = token_logprobs if token_logprobs else []
|
||||
top_logprobs = top_tokens if top_tokens else []
|
||||
|
||||
# Static response
|
||||
response = {
|
||||
"id": self.request_id,
|
||||
"system_fingerprint": self.system_fingerprint,
|
||||
"object": self.object_type,
|
||||
"model": self.requested_model,
|
||||
"created": self.created,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"logprobs": {
|
||||
"token_logprobs": token_logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
"tokens": tokens,
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
if not self.stream:
|
||||
if not (
|
||||
isinstance(prompt_token_count, int)
|
||||
and isinstance(completion_token_count, int)
|
||||
):
|
||||
raise ValueError(
|
||||
"Response type is complete, but token counts not provided"
|
||||
)
|
||||
|
||||
response["usage"] = {
|
||||
"prompt_tokens": prompt_token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": prompt_token_count + completion_token_count,
|
||||
}
|
||||
|
||||
choice = response["choices"][0]
|
||||
|
||||
# Add dynamic response
|
||||
if self.object_type.startswith("chat.completion"):
|
||||
key_name = "delta" if self.stream else "message"
|
||||
choice[key_name] = {"role": "assistant", "content": text}
|
||||
elif self.object_type == "text_completion":
|
||||
choice.update(text=text)
|
||||
else:
|
||||
ValueError(f"Unsupported response type: {self.object_type}")
|
||||
|
||||
return response
|
||||
|
||||
def get_prompt_cache(self, prompt):
|
||||
cache_len = len(self.prompt_cache.tokens)
|
||||
if (
|
||||
self.prompt_cache.model_key != self.model_provider.model_key
|
||||
or cache_len >= len(prompt)
|
||||
or self.prompt_cache.tokens != prompt[:cache_len]
|
||||
):
|
||||
self.prompt_cache.model_key = self.model_provider.model_key
|
||||
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
|
||||
else:
|
||||
prompt = prompt[cache_len:]
|
||||
self.prompt_cache.tokens.extend(prompt)
|
||||
return prompt
|
||||
|
||||
def handle_completion(
|
||||
self,
|
||||
prompt: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Generate a response to a prompt and send it to the client in a single batch.
|
||||
|
||||
Args:
|
||||
prompt (List[int]): The tokenized prompt.
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||
to the stopping_criteria function
|
||||
"""
|
||||
tokens = []
|
||||
finish_reason = "length"
|
||||
stop_sequence_suffix = None
|
||||
if self.stream:
|
||||
self.end_headers()
|
||||
logging.debug(f"Starting stream:")
|
||||
else:
|
||||
logging.debug(f"Starting completion:")
|
||||
token_logprobs = []
|
||||
top_tokens = []
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
text = ""
|
||||
tic = time.perf_counter()
|
||||
sampler = make_sampler(self.temperature, top_p=self.top_p)
|
||||
logits_processors = make_logits_processors(
|
||||
self.logit_bias, self.repetition_penalty, self.repetition_context_size
|
||||
)
|
||||
for gen_response in stream_generate(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
):
|
||||
segment = gen_response.text
|
||||
text += segment
|
||||
logging.debug(text)
|
||||
token = gen_response.token
|
||||
logprobs = gen_response.logprobs
|
||||
tokens.append(token)
|
||||
|
||||
if self.logprobs > 0:
|
||||
sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
|
||||
top_indices = sorted_indices[: self.logprobs]
|
||||
top_logprobs = logprobs[top_indices]
|
||||
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
|
||||
top_tokens.append(tuple(top_token_info))
|
||||
|
||||
token_logprobs.append(logprobs[token].item())
|
||||
|
||||
stop_condition = stopping_criteria(
|
||||
tokens, stop_id_sequences, self.tokenizer.eos_token_id
|
||||
)
|
||||
if stop_condition.stop_met:
|
||||
finish_reason = "stop"
|
||||
if stop_condition.trim_length:
|
||||
stop_sequence_suffix = self.tokenizer.decode(
|
||||
tokens[-stop_condition.trim_length :]
|
||||
)
|
||||
text = text[: -len(stop_sequence_suffix)]
|
||||
break
|
||||
|
||||
if self.stream:
|
||||
# If the end of tokens overlaps with a stop sequence, generate new
|
||||
# tokens until we know if the stop sequence is hit or not
|
||||
if any(
|
||||
(
|
||||
sequence_overlap(tokens, sequence)
|
||||
for sequence in stop_id_sequences
|
||||
)
|
||||
):
|
||||
continue
|
||||
elif segment:
|
||||
response = self.generate_response(segment, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
|
||||
logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
|
||||
logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
|
||||
logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
|
||||
|
||||
if self.stream:
|
||||
response = self.generate_response(segment, finish_reason)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
if self.stream_options is not None and self.stream_options["include_usage"]:
|
||||
response = self.completion_usage_response(len(prompt), len(tokens))
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
self.wfile.write("data: [DONE]\n\n".encode())
|
||||
self.wfile.flush()
|
||||
else:
|
||||
response = self.generate_response(
|
||||
text,
|
||||
finish_reason,
|
||||
len(prompt),
|
||||
len(tokens),
|
||||
token_logprobs=token_logprobs,
|
||||
top_tokens=top_tokens,
|
||||
tokens=tokens,
|
||||
)
|
||||
response_json = json.dumps(response).encode()
|
||||
indent = "\t" # Backslashes can't be inside of f-strings
|
||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||
|
||||
# Send an additional Content-Length header when it is known
|
||||
self.send_header("Content-Length", str(len(response_json)))
|
||||
self.end_headers()
|
||||
self.wfile.write(response_json)
|
||||
self.wfile.flush()
|
||||
|
||||
def completion_usage_response(
|
||||
self,
|
||||
prompt_token_count: Optional[int] = None,
|
||||
completion_token_count: Optional[int] = None,
|
||||
):
|
||||
response = {
|
||||
"id": self.request_id,
|
||||
"system_fingerprint": self.system_fingerprint,
|
||||
"object": "chat.completion",
|
||||
"model": self.requested_model,
|
||||
"created": self.created,
|
||||
"choices": [],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": prompt_token_count + completion_token_count,
|
||||
},
|
||||
}
|
||||
return response
|
||||
|
||||
def handle_chat_completions(self) -> List[int]:
|
||||
"""
|
||||
Handle a chat completion request.
|
||||
|
||||
Returns:
|
||||
mx.array: A mx.array of the tokenized prompt from the request body
|
||||
"""
|
||||
body = self.body
|
||||
assert "messages" in body, "Request did not contain messages"
|
||||
|
||||
# Determine response type
|
||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||
if self.tokenizer.chat_template:
|
||||
messages = body["messages"]
|
||||
process_message_content(messages)
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
body.get("tools", None),
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||
prompt = self.tokenizer.encode(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
def handle_text_completions(self) -> List[int]:
|
||||
"""
|
||||
Handle a text completion request.
|
||||
|
||||
Returns:
|
||||
mx.array: A mx.array of the tokenized prompt from the request body
|
||||
"""
|
||||
# Determine response type
|
||||
self.request_id = f"cmpl-{uuid.uuid4()}"
|
||||
self.object_type = "text_completion"
|
||||
assert "prompt" in self.body, "Request did not contain a prompt"
|
||||
return self.tokenizer.encode(self.body["prompt"])
|
||||
|
||||
def do_GET(self):
|
||||
"""
|
||||
Respond to a GET request from a client.
|
||||
"""
|
||||
if self.path == "/v1/models":
|
||||
self.handle_models_request()
|
||||
else:
|
||||
self._set_completion_headers(404)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"Not Found")
|
||||
|
||||
def handle_models_request(self):
|
||||
"""
|
||||
Handle a GET request for the /v1/models endpoint.
|
||||
"""
|
||||
self._set_completion_headers(200)
|
||||
self.end_headers()
|
||||
|
||||
# Scan the cache directory for downloaded mlx models
|
||||
hf_cache_info = scan_cache_dir()
|
||||
downloaded_models = [
|
||||
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
|
||||
]
|
||||
|
||||
# Create a list of available models
|
||||
models = [
|
||||
{
|
||||
"id": repo.repo_id,
|
||||
"object": "model",
|
||||
"created": self.created,
|
||||
}
|
||||
for repo in downloaded_models
|
||||
]
|
||||
|
||||
response = {"object": "list", "data": models}
|
||||
|
||||
response_json = json.dumps(response).encode()
|
||||
self.wfile.write(response_json)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
def run(
|
||||
host: str,
|
||||
port: int,
|
||||
model_provider: ModelProvider,
|
||||
server_class=HTTPServer,
|
||||
handler_class=APIHandler,
|
||||
):
|
||||
server_address = (host, port)
|
||||
prompt_cache = PromptCache()
|
||||
httpd = server_class(
|
||||
server_address,
|
||||
lambda *args, **kwargs: handler_class(
|
||||
model_provider,
|
||||
prompt_cache=prompt_cache,
|
||||
system_fingerprint=get_system_fingerprint(),
|
||||
*args,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
warnings.warn(
|
||||
"mlx_lm.server is not recommended for production as "
|
||||
"it only implements basic security checks."
|
||||
)
|
||||
logging.info(f"Starting httpd at {host} on port {port}...")
|
||||
httpd.serve_forever()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MLX Http Server.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the MLX model weights, tokenizer, and config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host for the HTTP server (default: 127.0.0.1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8080,
|
||||
help="Port for the HTTP server (default: 8080)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Enable trusting remote code for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the logging level (default: INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-limit-gb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the MLX cache limit in GB",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-template",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify a chat template for the tokenizer",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-default-chat-template",
|
||||
action="store_true",
|
||||
help="Use the default chat template",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level.upper(), None),
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
if args.cache_limit_gb is not None:
|
||||
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
run(args.host, args.port, ModelProvider(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,376 +0,0 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class StreamingDetokenizer:
|
||||
"""The streaming detokenizer interface so that we can detokenize one token at a time.
|
||||
|
||||
Example usage is as follows:
|
||||
|
||||
detokenizer = ...
|
||||
|
||||
# Reset the tokenizer state
|
||||
detokenizer.reset()
|
||||
|
||||
for token in generate(...):
|
||||
detokenizer.add_token(token.item())
|
||||
|
||||
# Contains the whole text so far. Some tokens may not be included
|
||||
# since it contains whole words usually.
|
||||
detokenizer.text
|
||||
|
||||
# Contains the printable segment (usually a word) since the last
|
||||
# time it was accessed
|
||||
detokenizer.last_segment
|
||||
|
||||
# Contains all the tokens added so far
|
||||
detokenizer.tokens
|
||||
|
||||
# Make sure that we detokenize any remaining tokens
|
||||
detokenizer.finalize()
|
||||
|
||||
# Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
|
||||
"""
|
||||
|
||||
__slots__ = ("text", "tokens", "offset")
|
||||
|
||||
def reset(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_token(self, token):
|
||||
raise NotImplementedError()
|
||||
|
||||
def finalize(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def last_segment(self):
|
||||
"""Return the last segment of readable text since last time this property was accessed."""
|
||||
text = self.text
|
||||
segment = text[self.offset :]
|
||||
self.offset = len(text)
|
||||
return segment
|
||||
|
||||
|
||||
class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
|
||||
implementation and should work with every tokenizer.
|
||||
|
||||
Its complexity is O(T^2) where T is the longest line since it will
|
||||
repeatedly detokenize the same tokens until a new line is generated.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self._tokenizer = tokenizer
|
||||
self._tokenizer.decode([0])
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.offset = 0
|
||||
self.tokens = []
|
||||
self._text = ""
|
||||
self._current_tokens = []
|
||||
self._current_text = ""
|
||||
|
||||
def add_token(self, token):
|
||||
self._current_tokens.append(token)
|
||||
self.tokens.append(token)
|
||||
|
||||
def finalize(self):
|
||||
self._text += self._tokenizer.decode(self._current_tokens)
|
||||
self._current_tokens = []
|
||||
self._current_text = ""
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
if self._current_tokens:
|
||||
self._current_text = self._tokenizer.decode(self._current_tokens)
|
||||
if (
|
||||
self._tokenizer.clean_up_tokenization_spaces
|
||||
and self._current_text[-1] == " "
|
||||
):
|
||||
self._current_text = self._current_text[:-1]
|
||||
if self._current_text and self._current_text[-1] == "\n":
|
||||
self._text += self._current_text
|
||||
self._current_tokens.clear()
|
||||
self._current_text = ""
|
||||
return self._text + self._current_text
|
||||
|
||||
|
||||
class SPMStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""A streaming detokenizer for SPM models.
|
||||
|
||||
It adds tokens to the text if the next token starts with the special SPM
|
||||
underscore which results in linear complexity.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, trim_space=True):
|
||||
self.trim_space = trim_space
|
||||
self._sep = "\u2581".encode()
|
||||
|
||||
# Extract the tokens in a list from id to text
|
||||
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
|
||||
for value, tokenid in tokenizer.vocab.items():
|
||||
if value.startswith("<0x"):
|
||||
# Replace bytes with their value
|
||||
self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
|
||||
else:
|
||||
self.tokenmap[tokenid] = value.encode()
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.offset = 0
|
||||
self._unflushed = b""
|
||||
self.text = ""
|
||||
self.tokens = []
|
||||
|
||||
def _try_flush(self, force=False):
|
||||
text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace")
|
||||
if not force and text.endswith("\ufffd"):
|
||||
return
|
||||
if not self.text and self.trim_space and text and text[0] == " ":
|
||||
text = text[1:]
|
||||
self.text += text
|
||||
self._unflushed = b""
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
v = self.tokenmap[token]
|
||||
self._unflushed += v
|
||||
self._try_flush()
|
||||
|
||||
def finalize(self):
|
||||
self._try_flush(force=True)
|
||||
self._unflushed = b""
|
||||
|
||||
|
||||
class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""A streaming detokenizer for OpenAI style BPE models.
|
||||
|
||||
It adds tokens to the text if the next token starts with a space similar to
|
||||
the SPM detokenizer.
|
||||
"""
|
||||
|
||||
_byte_decoder = None
|
||||
_space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
|
||||
|
||||
# Extract the tokens in a list from id to text
|
||||
self.tokenmap = [None] * len(tokenizer.vocab)
|
||||
for value, tokenid in tokenizer.vocab.items():
|
||||
self.tokenmap[tokenid] = value
|
||||
|
||||
self.reset()
|
||||
|
||||
# Make the BPE byte decoder from
|
||||
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||
self.make_byte_decoder()
|
||||
|
||||
def reset(self):
|
||||
self.offset = 0
|
||||
self._unflushed = ""
|
||||
self.text = ""
|
||||
self.tokens = []
|
||||
|
||||
def _decode_bytes(self, seq):
|
||||
barr = bytearray()
|
||||
for c in seq:
|
||||
res = self._byte_decoder.get(c, False)
|
||||
if res:
|
||||
barr.append(res)
|
||||
else:
|
||||
barr.extend(bytes(c, "utf-8"))
|
||||
return barr.decode("utf-8", "replace")
|
||||
|
||||
def _maybe_trim_space(self, current_text):
|
||||
if len(current_text) == 0:
|
||||
return current_text
|
||||
elif current_text[0] != " ":
|
||||
return current_text
|
||||
elif not self.text:
|
||||
return current_text[1:]
|
||||
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
|
||||
return current_text[1:]
|
||||
return current_text
|
||||
|
||||
def add_token(self, token):
|
||||
self.tokens.append(token)
|
||||
v = self.tokenmap[token]
|
||||
self._unflushed += v
|
||||
text = self._decode_bytes(self._unflushed)
|
||||
|
||||
# For multi-byte utf-8 wait until they are complete
|
||||
# For single spaces wait until the next token to clean it if needed
|
||||
if not text.endswith("\ufffd") and not (
|
||||
len(v) == 1 and self._byte_decoder[v[0]] == 32
|
||||
):
|
||||
self.text += self._maybe_trim_space(text)
|
||||
self._unflushed = ""
|
||||
|
||||
def finalize(self):
|
||||
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
|
||||
"utf-8",
|
||||
"replace",
|
||||
)
|
||||
self.text += self._maybe_trim_space(current_text)
|
||||
self._unflushed = ""
|
||||
|
||||
@classmethod
|
||||
def make_byte_decoder(cls):
|
||||
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
|
||||
if cls._byte_decoder is not None:
|
||||
return
|
||||
|
||||
char_to_bytes = {}
|
||||
limits = [
|
||||
0,
|
||||
ord("!"),
|
||||
ord("~") + 1,
|
||||
ord("¡"),
|
||||
ord("¬") + 1,
|
||||
ord("®"),
|
||||
ord("ÿ") + 1,
|
||||
]
|
||||
n = 0
|
||||
for i, (start, stop) in enumerate(zip(limits, limits[1:])):
|
||||
if i % 2 == 0:
|
||||
for b in range(start, stop):
|
||||
char_to_bytes[chr(2**8 + n)] = b
|
||||
n += 1
|
||||
else:
|
||||
for b in range(start, stop):
|
||||
char_to_bytes[chr(b)] = b
|
||||
cls._byte_decoder = char_to_bytes
|
||||
|
||||
|
||||
class TokenizerWrapper:
|
||||
"""A wrapper that combines an HF tokenizer and a detokenizer.
|
||||
|
||||
Accessing any attribute other than the ``detokenizer`` is forwarded to the
|
||||
huggingface tokenizer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
|
||||
):
|
||||
self._tokenizer = tokenizer
|
||||
self._detokenizer = detokenizer_class(tokenizer)
|
||||
self._eos_token_ids = (
|
||||
set(eos_token_ids)
|
||||
if eos_token_ids is not None
|
||||
else {tokenizer.eos_token_id}
|
||||
)
|
||||
|
||||
def add_eos_token(self, token: str):
|
||||
token_id = None
|
||||
try:
|
||||
token_id = int(token)
|
||||
except ValueError:
|
||||
token_id = self._tokenizer.convert_tokens_to_ids(token)
|
||||
|
||||
if token_id is None:
|
||||
raise ValueError(f"'{token}' is not a token for this tokenizer")
|
||||
|
||||
self._eos_token_ids.add(token_id)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr == "detokenizer":
|
||||
return self._detokenizer
|
||||
elif attr == "eos_token_ids":
|
||||
return self._eos_token_ids
|
||||
elif attr.startswith("_"):
|
||||
return self.__getattribute__(attr)
|
||||
else:
|
||||
return getattr(self._tokenizer, attr)
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
if attr in {"detokenizer", "eos_token_ids"}:
|
||||
if attr == "detokenizer":
|
||||
raise AttributeError("Cannot set the detokenizer.")
|
||||
elif attr == "eos_token_ids":
|
||||
self._eos_token_ids = set(value) if value is not None else set()
|
||||
elif attr.startswith("_"):
|
||||
super().__setattr__(attr, value)
|
||||
else:
|
||||
setattr(self._tokenizer, attr, value)
|
||||
|
||||
|
||||
def _match(a, b):
|
||||
if type(a) != type(b):
|
||||
return False
|
||||
if isinstance(a, dict):
|
||||
return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
|
||||
if isinstance(a, list):
|
||||
return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))
|
||||
|
||||
return a == b
|
||||
|
||||
|
||||
def _is_spm_decoder(decoder):
|
||||
_target_description = {
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
|
||||
{"type": "ByteFallback"},
|
||||
{"type": "Fuse"},
|
||||
{"type": "Strip", "content": " ", "start": 1, "stop": 0},
|
||||
],
|
||||
}
|
||||
return _match(_target_description, decoder)
|
||||
|
||||
|
||||
def _is_spm_decoder_no_space(decoder):
|
||||
_target_description = {
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
|
||||
{"type": "ByteFallback"},
|
||||
{"type": "Fuse"},
|
||||
],
|
||||
}
|
||||
return _match(_target_description, decoder)
|
||||
|
||||
|
||||
def _is_bpe_decoder(decoder):
|
||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||
|
||||
|
||||
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||
detokenizer to use.
|
||||
|
||||
Note: to use a fast streaming tokenizer, pass a local file path rather than
|
||||
a Hugging Face repo ID.
|
||||
"""
|
||||
detokenizer_class = NaiveStreamingDetokenizer
|
||||
|
||||
tokenizer_file = model_path / "tokenizer.json"
|
||||
if tokenizer_file.exists():
|
||||
with open(tokenizer_file, "r", encoding="utf-8") as fid:
|
||||
tokenizer_content = json.load(fid)
|
||||
if "decoder" in tokenizer_content:
|
||||
if _is_spm_decoder(tokenizer_content["decoder"]):
|
||||
detokenizer_class = SPMStreamingDetokenizer
|
||||
elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
|
||||
detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
|
||||
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
||||
detokenizer_class = BPEStreamingDetokenizer
|
||||
|
||||
if isinstance(eos_token_ids, int):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
return TokenizerWrapper(
|
||||
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
||||
detokenizer_class,
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
|
||||
removed_bos = sequence if sequence[0] != bos else sequence[1:]
|
||||
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
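A minimal sketch of streaming detokenization with `load_tokenizer` (the model directory is a placeholder for any locally converted model containing a `tokenizer.json`; the import path is this module's historical location):

```python
from pathlib import Path

from mlx_lm.tokenizer_utils import load_tokenizer

model_path = Path("path/to/local/model")  # hypothetical local model directory
tokenizer = load_tokenizer(model_path)

detokenizer = tokenizer.detokenizer
detokenizer.reset()
for token in tokenizer.encode("Hello world"):
    detokenizer.add_token(token)
    print(detokenizer.text)  # text decoded so far
detokenizer.finalize()
print(detokenizer.text)  # full decoded text
```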
|
||||
@@ -1,2 +0,0 @@
|
||||
from .trainer import TrainingArgs, evaluate, train
|
||||
from .utils import linear_to_lora_layers
|
||||
@@ -1,273 +0,0 @@
|
||||
import itertools
|
||||
import json
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
class Dataset:
|
||||
"""
|
||||
Light-weight wrapper to hold a dataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
text_key: str = "text",
|
||||
):
|
||||
self._data = [tokenizer.encode(d[text_key]) for d in data]
|
||||
for d in self._data:
|
||||
if d[-1] != tokenizer.eos_token_id:
|
||||
d.append(tokenizer.eos_token_id)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ChatDataset:
|
||||
"""
|
||||
A dataset for chat data in the format of {"messages": [...]}
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
chat_key: str = "messages",
|
||||
mask_prompt: bool = False,
|
||||
):
|
||||
self._data = []
|
||||
for d in data:
|
||||
messages = d[chat_key]
|
||||
tools = d.get("tools", None)
|
||||
tokens = tokenizer.apply_chat_template(messages, tools=tools)
|
||||
if mask_prompt:
|
||||
messages = messages[:-1]
|
||||
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
|
||||
self._data.append((tokens, offset))
|
||||
else:
|
||||
self._data.append(tokens)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class CompletionsDataset:
|
||||
"""
|
||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||
or using user-provided keys for prompt and completion values
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str,
|
||||
completion_key: str,
|
||||
mask_prompt: bool,
|
||||
):
|
||||
self._data = []
|
||||
for d in data:
|
||||
tokens = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[completion_key]},
|
||||
],
|
||||
)
|
||||
if mask_prompt:
|
||||
offset = len(
|
||||
tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": d[prompt_key]}]
|
||||
)
|
||||
)
|
||||
self._data.append((tokens, offset))
|
||||
else:
|
||||
self._data.append(tokens)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ConcatenatedDataset:
|
||||
def __init__(self, data: List[Any]):
|
||||
self._data = list(itertools.chain(*data))
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def create_dataset(
|
||||
data,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config,
|
||||
):
|
||||
mask_prompt = getattr(config, "mask_prompt", False)
|
||||
prompt_feature = getattr(config, "prompt_feature", "prompt")
|
||||
text_feature = getattr(config, "text_feature", "text")
|
||||
completion_feature = getattr(config, "completion_feature", "completion")
|
||||
chat_feature = getattr(config, "chat_feature", "messages")
|
||||
sample = data[0]
|
||||
if prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(
|
||||
data, tokenizer, prompt_feature, completion_feature, mask_prompt
|
||||
)
|
||||
elif chat_feature in sample:
|
||||
return ChatDataset(
|
||||
data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
|
||||
)
|
||||
elif text_feature in sample:
|
||||
if mask_prompt:
|
||||
raise ValueError("Prompt masking not supported for text dataset.")
|
||||
return Dataset(data, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
||||
)
|
||||
|
||||
|
||||
def load_local_dataset(
|
||||
data_path: Path,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config,
|
||||
):
|
||||
def load_subset(path):
|
||||
if not path.exists():
|
||||
return []
|
||||
with open(path, "r") as fid:
|
||||
data = [json.loads(l) for l in fid]
|
||||
return create_dataset(data, tokenizer, config)
|
||||
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||
return train, valid, test
|
||||
|
||||
|
||||
def load_hf_dataset(
|
||||
data_id: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config,
|
||||
):
|
||||
from datasets import exceptions, load_dataset
|
||||
|
||||
try:
|
||||
dataset = load_dataset(data_id)
|
||||
|
||||
names = ("train", "valid", "test")
|
||||
|
||||
train, valid, test = [
|
||||
(
|
||||
create_dataset(dataset[n], tokenizer, config)
|
||||
if n in dataset.keys()
|
||||
else []
|
||||
)
|
||||
for n in names
|
||||
]
|
||||
|
||||
except exceptions.DatasetNotFoundError:
|
||||
raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
|
||||
|
||||
return train, valid, test
|
||||
|
||||
|
||||
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
import datasets
|
||||
|
||||
def create_hf_dataset(dataset_name, config, split, hf_config):
|
||||
ds = datasets.load_dataset(
|
||||
dataset_name,
|
||||
split=split,
|
||||
**hf_config,
|
||||
)
|
||||
return create_dataset(ds, tokenizer, config)
|
||||
|
||||
dataset_collection = args.hf_dataset
|
||||
if isinstance(dataset_collection, dict):
|
||||
dataset_collection = [dataset_collection]
|
||||
|
||||
collection = []
|
||||
for ds in dataset_collection:
|
||||
ds_name = ds["name"]
|
||||
print(f"Loading Hugging Face dataset {ds_name}.")
|
||||
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
|
||||
config = types.SimpleNamespace(**ds)
|
||||
hf_config = ds.get("config", {})
|
||||
if args.train:
|
||||
train_split = ds.get("train_split", "train[:80%]")
|
||||
valid_split = ds.get("valid_split", "train[-10%:]")
|
||||
train = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
train_split,
|
||||
hf_config,
|
||||
)
|
||||
valid = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
valid_split,
|
||||
hf_config,
|
||||
)
|
||||
else:
|
||||
train, valid = [], []
|
||||
|
||||
if args.test:
|
||||
test_split = ds.get("test_split")
|
||||
test = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
test_split,
|
||||
hf_config,
|
||||
)
|
||||
else:
|
||||
test = []
|
||||
|
||||
collection.append((train, valid, test))
|
||||
|
||||
if len(collection) == 1:
|
||||
return collection[0]
|
||||
|
||||
# Otherwise concatenate them
|
||||
return tuple(map(ConcatenatedDataset, zip(*collection)))
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if getattr(args, "hf_dataset", False):
|
||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||
else:
|
||||
data_path = Path(args.data)
|
||||
if data_path.exists():
|
||||
train, valid, test = load_local_dataset(data_path, tokenizer, args)
|
||||
else:
|
||||
print(f"Loading Hugging Face dataset {args.data}.")
|
||||
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
|
||||
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
"Training set not found or empty. Must provide training set for fine-tuning."
|
||||
)
|
||||
if args.train and len(valid) == 0:
|
||||
raise ValueError(
|
||||
"Validation set not found or empty. Must provide validation set for fine-tuning."
|
||||
)
|
||||
if args.test and len(test) == 0:
|
||||
raise ValueError(
|
||||
"Test set not found or empty. Must provide test set for evaluation."
|
||||
)
|
||||
return train, valid, test
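As a concrete sketch of the on-disk layout `load_dataset` expects, the snippet below writes a tiny prompt/completion dataset and loads it. The directory name is illustrative; the tokenizer is the same one used by the tests later in this diff.

```python
import json
import types
from pathlib import Path

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import load_dataset

data_dir = Path("finetune_data")  # hypothetical data directory
data_dir.mkdir(exist_ok=True)
example = {"prompt": "What is the capital of France?", "completion": "Paris."}
for split in ("train", "valid"):
    with open(data_dir / f"{split}.jsonl", "w") as fid:
        for _ in range(4):
            fid.write(json.dumps(example) + "\n")

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")
args = types.SimpleNamespace(data=str(data_dir), train=True, test=False)
train_set, valid_set, test_set = load_dataset(args, tokenizer)

# Prompt/completion pairs are routed to CompletionsDataset.
print(type(train_set).__name__, len(train_set), len(valid_set), len(test_set))
```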
|
||||
@@ -1,228 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class DoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_base(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
input_dims *= 32 // linear.bits
|
||||
dora_lin = DoRALinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
dora_lin.set_linear(linear)
|
||||
return dora_lin
|
||||
|
||||
def fuse(self, de_quantize: bool = False):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = self._dequantized_weight()
|
||||
|
||||
# Use the same type as the linear weight
|
||||
dtype = weight.dtype
|
||||
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
|
||||
|
||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||
lora_a = self.lora_a.T.astype(dtype)
|
||||
weight = weight + lora_b @ lora_a
|
||||
norm_scale = self.m / mx.linalg.norm(weight, axis=1)
|
||||
fused_linear.weight = norm_scale[:, None] * weight
|
||||
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if self._is_quantized() and not de_quantize:
|
||||
fused_linear = nn.QuantizedLinear.from_linear(
|
||||
fused_linear,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.set_linear(nn.Linear(input_dims, output_dims, bias=bias))
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||
|
||||
def set_linear(self, linear):
|
||||
"""
|
||||
Set the self.linear layer and recompute self.m.
|
||||
"""
|
||||
self.linear = linear
|
||||
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
|
||||
|
||||
def _dequantized_weight(self):
|
||||
"""
|
||||
Return the weight of the linear layer, dequantized if the layer is quantized.
|
||||
"""
|
||||
weight = self.linear.weight
|
||||
if self._is_quantized():
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
self.linear.scales,
|
||||
self.linear.biases,
|
||||
self.linear.group_size,
|
||||
self.linear.bits,
|
||||
)
|
||||
return weight
|
||||
|
||||
def _is_quantized(self):
|
||||
return isinstance(self.linear, nn.QuantizedLinear)
|
||||
|
||||
def __call__(self, x):
|
||||
# Regular LoRA (without a bias)
|
||||
w = self._dequantized_weight()
|
||||
y = x @ w.T
|
||||
|
||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||
out = y + (self.scale * z).astype(x.dtype)
|
||||
|
||||
# Compute the norm of the adapted weights
|
||||
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
|
||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
||||
|
||||
# Remove the norm and scale by the learned magnitude
|
||||
out = (self.m / denom).astype(x.dtype) * out
|
||||
|
||||
if "bias" in self.linear:
|
||||
out = out + self.linear.bias
|
||||
return out
|
||||
|
||||
|
||||
class DoRAEmbedding(nn.Module):
|
||||
@staticmethod
def from_base(
|
||||
embedding: nn.Embedding,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
num_embeddings, dims = embedding.weight.shape
|
||||
|
||||
# TODO support quantized weights in DoRAEmbedding
if isinstance(embedding, nn.QuantizedEmbedding):
raise ValueError("DoRAEmbedding does not yet support quantization.")
|
||||
dora_embedding = DoRAEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
dims=dims,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
dora_embedding.set_embedding(embedding)
|
||||
return dora_embedding
|
||||
|
||||
def fuse(self, de_quantize: bool = False):
|
||||
embedding = self.embedding
|
||||
weight = embedding.weight
|
||||
|
||||
# Use the same type as the embedding weight
|
||||
dtype = weight.dtype
|
||||
|
||||
num_embeddings, dims = weight.shape
|
||||
fused_embedding = nn.Embedding(num_embeddings, dims)
|
||||
|
||||
lora_a = (self.scale * self.lora_a).astype(dtype)
|
||||
lora_b = self.lora_b.astype(dtype)
|
||||
weight = weight + lora_a @ lora_b
|
||||
norm_scale = self.m / mx.linalg.norm(weight, axis=1)
|
||||
fused_embedding.weight = norm_scale[:, None] * weight
|
||||
|
||||
return fused_embedding
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
dims: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular embedding layer weights
|
||||
self.set_embedding(nn.Embedding(num_embeddings, dims))
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(num_embeddings)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_embeddings, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, dims))
|
||||
|
||||
def set_embedding(self, embedding: nn.Module):
|
||||
self.embedding = embedding
|
||||
self.m = mx.linalg.norm(embedding.weight, axis=1)
|
||||
|
||||
def __call__(self, x):
|
||||
y = self.embedding(x)
|
||||
z = self.scale * self.lora_a[x] @ self.lora_b
|
||||
out = y + self.dropout(z).astype(y.dtype)
|
||||
|
||||
# Compute the norm of the adapted weights for the individual embeddings
|
||||
adapted = y + z
|
||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1))
|
||||
|
||||
# Remove the norm and scale by the learned magnitude
|
||||
out = (self.m[x] / denom)[..., None] * out
|
||||
|
||||
return out
|
||||
|
||||
def as_linear(self, x):
|
||||
y = self.embedding.as_linear(x)
|
||||
z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
|
||||
out = y + (self.scale * z).astype(x.dtype)
|
||||
|
||||
# Compute the norm of the adapted weights
|
||||
adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b
|
||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
||||
|
||||
# Remove the norm and scale by the learned magnitude
|
||||
out = (self.m / denom) * out
|
||||
|
||||
return out
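A short usage sketch for `DoRALinear`: with a freshly initialized adapter (`lora_b` is all zeros) the layer reproduces the base linear layer, and `fuse()` folds the magnitude and low-rank update back into a plain `nn.Linear`.

```python
import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.dora import DoRALinear

base = nn.Linear(64, 32)
dora = DoRALinear.from_base(base, r=8, dropout=0.0, scale=20.0)

x = mx.random.uniform(shape=(4, 64))
# lora_b starts at zero, so the DoRA layer matches the base layer ...
assert mx.allclose(dora(x), base(x), atol=1e-5)
# ... and fusing yields an equivalent nn.Linear.
fused = dora.fuse()
assert mx.allclose(fused(x), base(x), atol=1e-5)
```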
|
||||
@@ -1,285 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_base(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
input_dims *= 32 // linear.bits
|
||||
lora_lin = LoRALinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def fuse(self, de_quantize: bool = False):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
is_quantized = isinstance(linear, nn.QuantizedLinear)
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
if is_quantized:
|
||||
dtype = linear.scales.dtype
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
linear.scales,
|
||||
linear.biases,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||
lora_a = self.lora_a.T.astype(dtype)
|
||||
fused_linear.weight = weight + lora_b @ lora_a
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if is_quantized and not de_quantize:
|
||||
fused_linear = nn.QuantizedLinear.from_linear(
|
||||
fused_linear,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
y = self.linear(x)
|
||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||
return y + (self.scale * z).astype(x.dtype)
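A small sketch of the relationship between the unfused and fused forms of this layer: after perturbing `lora_b` (standing in for a trained adapter), `fuse()` folds the update `scale * (lora_b.T @ lora_a.T)` into the base weight and the outputs agree.

```python
import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.lora import LoRALinear

base = nn.Linear(64, 32, bias=False)
lora = LoRALinear.from_base(base, r=8, dropout=0.0, scale=20.0)

# Pretend the adapter has been trained by giving lora_b nonzero values.
lora.lora_b = 0.01 * mx.random.normal(shape=lora.lora_b.shape)

x = mx.random.uniform(shape=(4, 64))
fused = lora.fuse()
assert mx.allclose(lora(x), fused(x), atol=1e-4)
```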
|
||||
|
||||
|
||||
class LoRASwitchLinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_base(
|
||||
linear: nn.Module,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
lora_lin = LoRASwitchLinear(
|
||||
input_dims=linear.input_dims,
|
||||
output_dims=linear.output_dims,
|
||||
num_experts=linear.num_experts,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def fuse(self, de_quantize: bool = False):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
is_quantized = isinstance(linear, QuantizedSwitchLinear)
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
if is_quantized:
|
||||
dtype = mx.float16
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
linear.scales,
|
||||
linear.biases,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
num_experts, output_dims, input_dims = weight.shape
|
||||
fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
|
||||
|
||||
lora_b = (self.scale * self.lora_b).astype(dtype)
|
||||
lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype)
|
||||
fused_linear.weight = weight + lora_b @ lora_a
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if is_quantized and not de_quantize:
|
||||
fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits)
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(r * num_experts, input_dims),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(num_experts, output_dims, r))
|
||||
self.num_experts = num_experts
|
||||
|
||||
def __call__(self, x, indices):
|
||||
shape = x.shape[:-3] + (self.num_experts, -1)
|
||||
|
||||
y = self.linear(x, indices)
|
||||
z = (self.dropout(x) @ self.lora_a.T).reshape(shape)
|
||||
z = mx.take_along_axis(z, indices[..., None], axis=-2)
|
||||
z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
|
||||
|
||||
return y + (self.scale * z).astype(x.dtype)
|
||||
|
||||
|
||||
class LoRAEmbedding(nn.Module):
|
||||
@staticmethod
|
||||
def from_base(
|
||||
embedding: nn.Embedding,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
num_embeddings, dims = embedding.weight.shape
|
||||
if isinstance(embedding, nn.QuantizedEmbedding):
|
||||
dims *= 32 // embedding.bits
|
||||
lora_embedding = LoRAEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
dims=dims,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
lora_embedding.embedding = embedding
|
||||
return lora_embedding
|
||||
|
||||
def fuse(self, de_quantize: bool = False):
|
||||
embedding = self.embedding
|
||||
weight = embedding.weight
|
||||
is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
|
||||
|
||||
# Use the same type as the embedding weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
if is_quantized:
|
||||
dtype = embedding.scales.dtype
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
embedding.scales,
|
||||
embedding.biases,
|
||||
embedding.group_size,
|
||||
embedding.bits,
|
||||
)
|
||||
num_embeddings, dims = weight.shape
|
||||
fused_embedding = nn.Embedding(num_embeddings, dims)
|
||||
|
||||
lora_a = (self.scale * self.lora_a).astype(dtype)
|
||||
lora_b = self.lora_b.astype(dtype)
|
||||
fused_embedding.weight = weight + lora_a @ lora_b
|
||||
|
||||
if is_quantized and not de_quantize:
|
||||
fused_embedding = nn.QuantizedEmbedding.from_embedding(
|
||||
fused_embedding,
|
||||
embedding.group_size,
|
||||
embedding.bits,
|
||||
)
|
||||
|
||||
return fused_embedding
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
dims: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular embedding layer
|
||||
self.embedding = nn.Embedding(num_embeddings, dims)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(num_embeddings)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_embeddings, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, dims))
|
||||
|
||||
def __call__(self, x):
|
||||
y = self.embedding(x)
|
||||
z = self.dropout(self.lora_a[x] @ self.lora_b)
|
||||
out = y + (self.scale * z).astype(y.dtype)
|
||||
return out
|
||||
|
||||
def as_linear(self, x):
|
||||
y = self.embedding.as_linear(x)
|
||||
z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
|
||||
return y + (self.scale * z).astype(x.dtype)
|
||||
@@ -1,343 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import glob
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .datasets import CompletionsDataset
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
"""
|
||||
Update all instances of type(layer) to use gradient checkpointing.
|
||||
"""
|
||||
fn = type(layer).__call__
|
||||
|
||||
def checkpointed_fn(model, *args, **kwargs):
|
||||
def inner_fn(params, *args, **kwargs):
|
||||
model.update(params)
|
||||
return fn(model, *args, **kwargs)
|
||||
|
||||
return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
|
||||
|
||||
type(layer).__call__ = checkpointed_fn
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArgs:
|
||||
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
|
||||
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
|
||||
val_batches: int = field(
|
||||
default=25,
|
||||
metadata={
|
||||
"help": "Number of validation batches, -1 uses the entire validation set."
|
||||
},
|
||||
)
|
||||
steps_per_report: int = field(
|
||||
default=10,
|
||||
metadata={"help": "Number of training steps between loss reporting."},
|
||||
)
|
||||
steps_per_eval: int = field(
|
||||
default=200, metadata={"help": "Number of training steps between validations."}
|
||||
)
|
||||
steps_per_save: int = field(
|
||||
default=100, metadata={"help": "Save the model every number steps"}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048, metadata={"help": "Maximum sequence length."}
|
||||
)
|
||||
adapter_file: str = field(
|
||||
default="adapters.safetensors",
|
||||
metadata={"help": "Save/load path for the trained adapter weights."},
|
||||
)
|
||||
grad_checkpoint: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use gradient checkpointing to reduce memory use."},
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, batch, lengths):
|
||||
inputs = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
steps = mx.arange(1, targets.shape[1] + 1)
|
||||
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||
ntoks = mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
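To make the masking explicit: `lengths` is the (offset, length) array yielded by `iterate_batches` below, where the offset is the number of prompt tokens excluded when prompt masking is on. A tiny illustration of the mask it induces, with made-up numbers:

```python
import mlx.core as mx

# Two padded sequences whose targets span 6 positions. The first has a
# 2-token prompt offset and 5 real tokens; the second has no prompt
# masking and 6 real tokens.
lengths = mx.array([[2, 5], [0, 6]])
steps = mx.arange(1, 6 + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)
# Row 0 keeps steps 2 through 5 (prompt and trailing padding are masked);
# row 1 keeps every step up to length 6.
```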
|
||||
|
||||
|
||||
def iterate_batches(
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
max_seq_length,
|
||||
train=False,
|
||||
):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
if len(dataset) < batch_size:
|
||||
raise ValueError(
|
||||
f"Dataset must have at least batch_size={batch_size}"
|
||||
f" examples but only has {len(dataset)}."
|
||||
)
|
||||
|
||||
# If running in distributed mode (N machines) then each one should skip N-1
|
||||
# samples
|
||||
step = mx.distributed.init().size()
|
||||
if batch_size % step != 0:
|
||||
raise ValueError("The batch size must be divisible by the number of workers")
|
||||
|
||||
# Make the batches:
|
||||
batch_idx = [
|
||||
idx[i : i + batch_size : step]
|
||||
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
]
|
||||
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
if len(batch[0]) == 2:
|
||||
batch, offsets = zip(*batch)
|
||||
else:
|
||||
offsets = [0] * len(batch)
|
||||
lengths = [len(x) for x in batch]
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
|
||||
"Consider pre-splitting your data to save memory."
|
||||
)
|
||||
|
||||
# Pad to the nearest multiple of 8 or the maximum length
|
||||
pad_to = 8
|
||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||
|
||||
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||
|
||||
for j in range(batch_size // step):
|
||||
truncated_length = min(lengths[j], max_seq_length)
|
||||
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
||||
lengths[j] = (
|
||||
truncated_length # Update lengths to match truncated lengths
|
||||
)
|
||||
batch = mx.array(batch_arr)
|
||||
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
||||
|
||||
def evaluate(
|
||||
model,
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
num_batches,
|
||||
max_seq_length=2048,
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches,
|
||||
):
|
||||
all_losses = mx.array(0.0)
|
||||
ntokens = mx.array(0)
|
||||
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
|
||||
for _, batch in zip(
|
||||
index_iterator,
|
||||
iterate_batches(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=batch_size,
|
||||
max_seq_length=max_seq_length,
|
||||
),
|
||||
):
|
||||
losses, toks = loss(model, *batch)
|
||||
all_losses += losses * toks
|
||||
ntokens += toks
|
||||
mx.eval(all_losses, ntokens)
|
||||
|
||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||
|
||||
return (all_losses / ntokens).item()
|
||||
|
||||
|
||||
class TrainingCallback:
|
||||
|
||||
def on_train_loss_report(self, train_info: dict):
|
||||
"""Called to report training loss at specified intervals."""
|
||||
pass
|
||||
|
||||
def on_val_loss_report(self, val_info: dict):
|
||||
"""Called to report validation loss at specified intervals or the beginning."""
|
||||
pass
|
||||
|
||||
|
||||
def train(
|
||||
model,
|
||||
tokenizer,
|
||||
optimizer,
|
||||
train_dataset,
|
||||
val_dataset,
|
||||
args: TrainingArgs = TrainingArgs(),
|
||||
loss: callable = default_loss,
|
||||
iterate_batches: callable = iterate_batches,
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
print(f"Starting training..., iters: {args.iters}")
|
||||
world = mx.distributed.init()
|
||||
world_size = world.size()
|
||||
rank = world.rank()
|
||||
if world_size > 1:
|
||||
print(f"Node {rank} of {world_size}")
|
||||
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
# Forward and backward pass
|
||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||
|
||||
# All reduce the gradients if running in distributed mode
|
||||
grad = average_gradients(grad)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
return lvalue, toks
|
||||
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
train_time = 0
|
||||
# Main training loop
|
||||
for it, batch in zip(
|
||||
range(1, args.iters + 1),
|
||||
iterate_batches(
|
||||
dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
max_seq_length=args.max_seq_length,
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
tic = time.perf_counter()
|
||||
# Report validation loss if needed; the first validation loss
# is always measured before any training.
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
tic = time.perf_counter()
|
||||
val_loss = evaluate(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
loss=loss,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - tic
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val took {val_time:.3f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
val_info = {
|
||||
"iteration": it,
|
||||
"val_loss": val_loss,
|
||||
"val_time": val_time,
|
||||
}
|
||||
training_callback.on_val_loss_report(val_info)
|
||||
|
||||
tic = time.perf_counter()
|
||||
|
||||
lvalue, toks = step(batch)
|
||||
losses += lvalue
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
mx.eval(state, losses, n_tokens)
|
||||
train_time += time.perf_counter() - tic
|
||||
|
||||
# Report training loss if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / train_time
|
||||
tokens_sec = float(n_tokens) / train_time
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||
f"Learning Rate {learning_rate:.3e}, "
|
||||
f"It/sec {it_sec:.3f}, "
|
||||
f"Tokens/sec {tokens_sec:.3f}, "
|
||||
f"Trained Tokens {trained_tokens}, "
|
||||
f"Peak mem {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
train_info = {
|
||||
"iteration": it,
|
||||
"train_loss": train_loss,
|
||||
"learning_rate": learning_rate,
|
||||
"iterations_per_second": it_sec,
|
||||
"tokens_per_second": tokens_sec,
|
||||
"trained_tokens": trained_tokens,
|
||||
"peak_memory": peak_mem,
|
||||
}
|
||||
training_callback.on_train_loss_report(train_info)
|
||||
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
train_time = 0
|
||||
|
||||
# Save adapter weights
|
||||
if it % args.steps_per_save == 0:
|
||||
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||
checkpoint = (
|
||||
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
|
||||
)
|
||||
mx.save_safetensors(str(checkpoint), adapter_weights)
|
||||
print(
|
||||
f"Iter {it}: Saved adapter weights to "
|
||||
f"{args.adapter_file} and {checkpoint}."
|
||||
)
|
||||
|
||||
# Save final weights
|
||||
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||
print(f"Saved final weights to {args.adapter_file}.")
|
||||
@@ -1,276 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
import json
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
|
||||
from .dora import DoRAEmbedding, DoRALinear
|
||||
from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
|
||||
|
||||
|
||||
def build_schedule(schedule_config: Dict):
|
||||
"""
|
||||
Build a learning rate schedule from the given config.
|
||||
"""
|
||||
schedule_fn = getattr(opt.schedulers, schedule_config["name"])
|
||||
arguments = schedule_config["arguments"]
|
||||
initial_lr = arguments[0]
|
||||
bound_schedule_fn = schedule_fn(*arguments)
|
||||
if warmup_steps := schedule_config.get("warmup", 0):
|
||||
warmup_init = schedule_config.get("warmup_init", 0.0)
|
||||
warmup_fn = opt.schedulers.linear_schedule(
|
||||
warmup_init, initial_lr, warmup_steps
|
||||
)
|
||||
return opt.schedulers.join_schedules(
|
||||
[warmup_fn, bound_schedule_fn], [warmup_steps + 1]
|
||||
)
|
||||
else:
|
||||
return bound_schedule_fn
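For reference, a sketch of the config shape this function expects; the values are illustrative, and `cosine_decay` takes the initial learning rate and the number of decay steps.

```python
import mlx.optimizers as opt

from mlx_lm.tuner.utils import build_schedule

lr_config = {
    "name": "cosine_decay",
    "arguments": [1e-5, 1000],  # initial learning rate, decay steps
    "warmup": 100,              # optional linear warmup steps
    "warmup_init": 1e-7,        # optional warmup starting value
}
optimizer = opt.Adam(learning_rate=build_schedule(lr_config))
```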
|
||||
|
||||
|
||||
def linear_to_lora_layers(
|
||||
model: nn.Module,
|
||||
num_layers: int,
|
||||
config: Dict,
|
||||
use_dora: bool = False,
|
||||
):
|
||||
"""
|
||||
Convert some of the model's linear layers to LoRA layers.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The neural network model.
|
||||
num_layers (int): The number of blocks to convert to lora layers
|
||||
starting from the last layer.
|
||||
config (dict): More configuration parameters for LoRA, including the
|
||||
rank, scale, and optional layer keys.
|
||||
use_dora (bool): If True, uses DoRA instead of LoRA.
|
||||
Default: ``False``
|
||||
"""
|
||||
|
||||
def to_lora(layer):
|
||||
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
|
||||
LoRALayer = DoRALinear if use_dora else LoRALinear
|
||||
elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
|
||||
if use_dora:
|
||||
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
|
||||
LoRALayer = LoRASwitchLinear
|
||||
elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
|
||||
LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Can't convert layer of type {type(layer).__name__} to LoRA"
|
||||
)
|
||||
|
||||
return LoRALayer.from_base(
|
||||
layer,
|
||||
r=config["rank"],
|
||||
scale=config["scale"],
|
||||
dropout=config["dropout"],
|
||||
)
|
||||
|
||||
keys = config.get("keys", None)
|
||||
if keys is not None:
|
||||
keys = set(keys)
|
||||
elif model.model_type in [
|
||||
"mistral",
|
||||
"llama",
|
||||
"phi",
|
||||
"mixtral",
|
||||
"nemotron",
|
||||
"stablelm",
|
||||
"hunyuan",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"phimoe",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"granite",
|
||||
"helium",
|
||||
"starcoder2",
|
||||
"cohere",
|
||||
"cohere2",
|
||||
"minicpm",
|
||||
"deepseek",
|
||||
"olmo2",
|
||||
"olmoe",
|
||||
"internlm3",
|
||||
]:
|
||||
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
|
||||
if model.model_type in ["mixtral", "phimoe"]:
|
||||
keys.add("block_sparse_moe.gate")
|
||||
if model.model_type == "qwen2_moe":
|
||||
keys.add("mlp.gate")
|
||||
keys.add("mlp.shared_expert_gate")
|
||||
if model.model_type == "olmoe":
|
||||
keys.add("mlp.gate")
|
||||
|
||||
elif model.model_type == "gpt_bigcode":
|
||||
keys = set(["attn.c_attn"])
|
||||
elif model.model_type == "gpt2":
|
||||
keys = set(["attn.c_attn"])
|
||||
elif model.model_type == "gpt_neox":
|
||||
keys = set(["attention.query_key_value"])
|
||||
elif model.model_type == "olmo":
|
||||
keys = set(["att_proj"])
|
||||
elif model.model_type == "openelm":
|
||||
keys = set(["attn.qkv_proj"])
|
||||
elif model.model_type == "phi3":
|
||||
keys = set(["self_attn.qkv_proj"])
|
||||
elif model.model_type == "phi-msft":
|
||||
keys = set(["mixer.Wqkv", "moe.gate"])
|
||||
elif model.model_type == "dbrx":
|
||||
keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
|
||||
elif model.model_type == "internlm2":
|
||||
keys = set(["attention.wqkv", "attention.wo"])
|
||||
elif model.model_type == "deepseek_v2":
|
||||
keys = set(
|
||||
[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.q_a_proj",
|
||||
"self_attn.q_b_proj",
|
||||
"self_attn.kv_a_proj_with_mqa",
|
||||
"self_attn.kv_b_proj",
|
||||
]
|
||||
)
|
||||
elif model.model_type == "mamba":
|
||||
keys = set(
|
||||
[
|
||||
"mixer.in_proj",
|
||||
"mixer.x_proj",
|
||||
"mixer.dt_proj",
|
||||
"mixer.out_proj",
|
||||
]
|
||||
)
|
||||
elif model.model_type == "exaone":
|
||||
keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
|
||||
else:
|
||||
raise ValueError(f"Lora does not support {model.model_type}")
|
||||
|
||||
for l in model.layers[-max(num_layers, 0) :]:
|
||||
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
||||
if lora_layers:
|
||||
l.update_modules(tree_unflatten(lora_layers))
|
||||
|
||||
lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
|
||||
if lora_modules:
|
||||
model.update_modules(tree_unflatten(lora_modules))
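A sketch of driving this conversion directly, reusing the small llama configuration from the test file further down in this diff; the config keys (`rank`, `scale`, `dropout`, and optionally `keys`) are the ones read above.

```python
from mlx_lm.models import llama
from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters

args = llama.ModelArgs(
    model_type="llama",
    hidden_size=1024,
    num_hidden_layers=4,
    intermediate_size=2048,
    num_attention_heads=4,
    rms_norm_eps=1e-5,
    vocab_size=10_000,
)
model = llama.Model(args)
model.freeze()

lora_config = {"rank": 8, "scale": 20.0, "dropout": 0.0}
linear_to_lora_layers(model, num_layers=2, config=lora_config)
print_trainable_parameters(model)
```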
|
||||
|
||||
|
||||
def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
|
||||
"""
|
||||
Load any fine-tuned adapters / layers.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The neural network model.
|
||||
adapter_path (str): Path to the adapter configuration file.
|
||||
|
||||
Returns:
|
||||
nn.Module: The updated model with LoRA layers applied.
|
||||
"""
|
||||
adapter_path = Path(adapter_path)
|
||||
if not adapter_path.exists():
|
||||
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
|
||||
with open(adapter_path / "adapter_config.json", "r") as fid:
|
||||
config = types.SimpleNamespace(**json.load(fid))
|
||||
fine_tune_type = getattr(config, "fine_tune_type", "lora")
|
||||
if fine_tune_type != "full":
|
||||
linear_to_lora_layers(
|
||||
model,
|
||||
config.num_layers,
|
||||
config.lora_parameters,
|
||||
use_dora=(fine_tune_type == "dora"),
|
||||
)
|
||||
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def dequantize(model: nn.Module) -> nn.Module:
|
||||
"""
|
||||
Dequantize the quantized linear layers in the model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model with quantized linear layers.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model with dequantized layers.
|
||||
"""
|
||||
de_quantize_layers = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.QuantizedLinear):
|
||||
bias = "bias" in module
|
||||
weight = module.weight
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
module.scales,
|
||||
module.biases,
|
||||
module.group_size,
|
||||
module.bits,
|
||||
).astype(mx.float16)
|
||||
output_dims, input_dims = weight.shape
|
||||
linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
linear.weight = weight
|
||||
if bias:
|
||||
linear.bias = module.bias
|
||||
de_quantize_layers.append((name, linear))
|
||||
if isinstance(module, nn.QuantizedEmbedding):
|
||||
weight = mx.dequantize(
|
||||
module.weight,
|
||||
module.scales,
|
||||
module.biases,
|
||||
module.group_size,
|
||||
module.bits,
|
||||
).astype(mx.float16)
|
||||
num_embeddings, dims = weight.shape
|
||||
emb = nn.Embedding(num_embeddings, dims)
|
||||
emb.weight = weight
|
||||
de_quantize_layers.append((name, emb))
|
||||
|
||||
if len(de_quantize_layers) > 0:
|
||||
model.update_modules(tree_unflatten(de_quantize_layers))
|
||||
return model
|
||||
|
||||
|
||||
def remove_lora_layers(model: nn.Module) -> nn.Module:
|
||||
"""
|
||||
Remove the LoRA layers from the model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model with LoRA layers.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model without LoRA layers.
|
||||
"""
|
||||
reset_layers = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
reset_layers.append((name, module.linear))
|
||||
if len(reset_layers) > 0:
|
||||
model.update_modules(tree_unflatten(reset_layers))
|
||||
return model
|
||||
|
||||
|
||||
def nparams(module):
|
||||
if hasattr(module, "bits"):
|
||||
n = 0 if not hasattr(module, "bias") else module.bias.size
|
||||
return n + module.weight.size * 32 // module.bits
|
||||
return sum(v.size for _, v in tree_flatten(module.parameters()))
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
leaf_modules = tree_flatten(
|
||||
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
|
||||
)
|
||||
total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
|
||||
trainable_p = (
|
||||
sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
|
||||
)
|
||||
print(
|
||||
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
|
||||
f"({trainable_p:.3f}M/{total_p:.3f}M)"
|
||||
)
|
||||
llms/mlx_lm/utils.py (1118 lines): file diff suppressed because it is too large.
@@ -1,47 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
package_dir = Path(__file__).parent / "mlx_lm"
|
||||
with open(package_dir / "requirements.txt") as fid:
|
||||
requirements = [l.strip() for l in fid.readlines()]
|
||||
|
||||
sys.path.append(str(package_dir))
|
||||
from _version import __version__
|
||||
|
||||
setup(
|
||||
name="mlx-lm",
|
||||
version=__version__,
|
||||
description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
readme="README.md",
|
||||
author_email="mlx@group.apple.com",
|
||||
author="MLX Contributors",
|
||||
url="https://github.com/ml-explore/mlx-examples",
|
||||
license="MIT",
|
||||
install_requires=requirements,
|
||||
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
|
||||
python_requires=">=3.8",
|
||||
extras_require={
|
||||
"test": ["datasets"],
|
||||
"evaluate": ["lm-eval", "tqdm"],
|
||||
},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
|
||||
"mlx_lm.chat = mlx_lm.chat:main",
|
||||
"mlx_lm.convert = mlx_lm.convert:main",
|
||||
"mlx_lm.evaluate = mlx_lm.evaluate:main",
|
||||
"mlx_lm.fuse = mlx_lm.fuse:main",
|
||||
"mlx_lm.generate = mlx_lm.generate:main",
|
||||
"mlx_lm.lora = mlx_lm.lora:main",
|
||||
"mlx_lm.merge = mlx_lm.merge:main",
|
||||
"mlx_lm.server = mlx_lm.server:main",
|
||||
"mlx_lm.manage = mlx_lm.manage:main",
|
||||
]
|
||||
},
|
||||
)
|
||||
@@ -1,113 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from mlx_lm.tuner import datasets
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
|
||||
|
||||
class TestDatasets(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir_fid.name)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def save_data(self, data):
|
||||
for ds in ["train", "valid"]:
|
||||
with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid:
|
||||
for l in data:
|
||||
json.dump(l, fid)
|
||||
fid.write("\n")
|
||||
|
||||
def test_text(self):
|
||||
data = {"text": "This is an example for the model."}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.Dataset))
|
||||
|
||||
def test_completions(self):
|
||||
data = {"prompt": "What is the capital of France?", "completion": "Paris."}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.CompletionsDataset))
|
||||
|
||||
def test_chat(self):
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello."},
|
||||
{"role": "assistant", "content": "How can I assistant you today."},
|
||||
]
|
||||
}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||
|
||||
def test_hf(self):
|
||||
hf_args = {
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
}
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset=hf_args,
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertTrue(len(train) > 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertEqual(len(test), 0)
|
||||
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset=[hf_args, hf_args],
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(2 * len(train), len(train_double))
|
||||
self.assertEqual(2 * len(valid), len(valid_double))
|
||||
self.assertEqual(2 * len(test), len(test_double))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,447 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx_lm import lora, tuner
|
||||
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
|
||||
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
|
||||
from mlx_lm.tuner.trainer import evaluate
|
||||
from mlx_lm.tuner.utils import build_schedule
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swapped_with_identity(obj, func):
|
||||
old_func = getattr(obj, func)
|
||||
setattr(obj, func, lambda x, **kwargs: x)
|
||||
yield
|
||||
setattr(obj, func, old_func)
|
||||
|
||||
|
||||
class TestLora(unittest.TestCase):
    def setUp(self):
        self.capturedOutput = StringIO()
        sys.stdout = self.capturedOutput

    def tearDown(self):
        sys.stdout = sys.__stdout__

    def test_llama(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            tie_word_embeddings=False,
        )

        lora_layers = 4

        def check_config(params, expected_trainable_parameters=None):
            n_keys = 2
            if "keys" in params:
                n_keys = len(params["keys"])
            model = llama.Model(args)
            model.freeze()
            tuner.utils.linear_to_lora_layers(model, lora_layers, params)
            trainable_params = sum(
                v.size for _, v in tree_flatten(model.trainable_parameters())
            )

            expected_trainable_parameters = expected_trainable_parameters or (
                lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
            )
            self.assertEqual(trainable_params, expected_trainable_parameters)

        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
        check_config(params)

        params["rank"] = 1
        check_config(params)

        params["keys"] = ["self_attn.k_proj"]
        check_config(params)

        params["keys"] = ["lm_head"]
        check_config(
            params,
            expected_trainable_parameters=(
                params["rank"] * (args.hidden_size + args.vocab_size)
            ),
        )

        params["keys"] = ["model.embed_tokens"]
        check_config(
            params,
            expected_trainable_parameters=(
                params["rank"] * (args.hidden_size + args.vocab_size)
            ),
        )

    def test_gpt_neox(self):
        from mlx_lm.models import gpt_neox

        args = gpt_neox.ModelArgs(
            model_type="gpt_neox",
            max_position_embeddings=2048,
            hidden_size=6144,
            num_attention_heads=64,
            num_hidden_layers=44,
            layer_norm_eps=1e-5,
            vocab_size=50432,
            rotary_emb_base=10_000,
            rotary_pct=0.25,
        )

        num_lora_layers = 4
        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}

        model = gpt_neox.Model(args)
        model.freeze()
        tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)

    def test_lora_embedding(self):
        num_embeddings = 256
        dims = 512
        tokens = mx.array([1, 2, 3])

        embedding = nn.QuantizedEmbedding(num_embeddings, dims)
        dequantized_weight = mx.dequantize(
            embedding.weight,
            embedding.scales,
            embedding.biases,
            embedding.group_size,
            embedding.bits,
        )
        lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
        new_embedding = lora_emb.fuse(de_quantize=True)
        self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
        self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))

        # as_linear
        attn_output = mx.random.uniform(shape=(dims,))
        embedding_lin_out = lora_emb.as_linear(attn_output)
        self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
        self.assertTrue(
            mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
        )

        # change the value of lora_b and the embeddings will no longer be equal
        lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
        new_embedding = lora_emb.fuse(de_quantize=True)
        self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
        self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))

class TestDora(unittest.TestCase):
    def test_dora_embedding(self):
        num_embeddings = 256
        dims = 512
        tokens = mx.array([1, 2, 3])

        embedding = nn.Embedding(num_embeddings, dims)

        dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
        new_embedding = dora_emb.fuse()
        self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
        self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))

        # as_linear
        attn_output = mx.random.uniform(shape=(dims,))
        embedding_lin_out = dora_emb.as_linear(attn_output)
        self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
        self.assertTrue(
            mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
        )

        # change the value of lora_b and the embeddings will no longer be equal
        dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
        new_embedding = dora_emb.fuse()
        self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
        self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))

    def test_llama(self):
        from mlx_lm.models import llama

        hidden_size = 1024
        intermediate_size = 2048
        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=hidden_size,
            num_hidden_layers=4,
            intermediate_size=intermediate_size,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )

        dora_layers = 4

        def check_config(params):
            n_keys = 2
            if "keys" in params:
                n_keys = len(params["keys"])
            model = llama.Model(args)
            model.freeze()
            tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
            trainable_params = sum(
                v.size for _, v in tree_flatten(model.trainable_parameters())
            )
            self.assertEqual(
                trainable_params,
                dora_layers
                * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
            )

        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
        check_config(params)

        params["rank"] = 1
        check_config(params)

        params["keys"] = ["self_attn.k_proj"]
        check_config(params)

    def test_dora_m_parameter(self):
        dora_lin = DoRALinear(input_dims=100, output_dims=100)
        self.assertTrue(
            mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
        )

        # Recomputes m when changing Linear
        initial_m = dora_lin.m
        lin = nn.Linear(10, 10)
        dora_lin.set_linear(lin)
        self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))

        # Works with quantized weights
        quantized_linear = nn.QuantizedLinear(512, 512)
        dora_lin.set_linear(quantized_linear)
        dequantized_weight = mx.dequantize(
            quantized_linear.weight,
            quantized_linear.scales,
            quantized_linear.biases,
            quantized_linear.group_size,
            quantized_linear.bits,
        )
        self.assertTrue(
            mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
        )

    def test_dora_from_linear(self):
        in_dims = 256
        out_dims = 256
        r = 4

        linear = nn.Linear(in_dims, out_dims)
        dora_lin = DoRALinear.from_base(linear, r)
        self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
        self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
        self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
        self.assertEqual(dora_lin.m.shape, (out_dims,))

        quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
        dequantized_weight = mx.dequantize(
            quantized_linear.weight,
            quantized_linear.scales,
            quantized_linear.biases,
            quantized_linear.group_size,
            quantized_linear.bits,
        )
        dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
        self.assertTrue(
            mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
        )
        self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
        self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
        self.assertEqual(dora_quant_lin.m.shape, (out_dims,))

    def test_dora_to_linear(self):
        in_dims = 256
        out_dims = 256
        r = 4

        linear = nn.Linear(in_dims, out_dims, bias=True)
        dora_lin = DoRALinear.from_base(linear, r)
        to_linear = dora_lin.fuse()
        self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
        self.assertTrue(mx.allclose(linear.bias, to_linear.bias))

        def dequantize_weight(quantized_linear):
            return mx.dequantize(
                quantized_linear.weight,
                quantized_linear.scales,
                quantized_linear.biases,
                quantized_linear.group_size,
                quantized_linear.bits,
            )

        quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
        dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
        # Dequantize
        to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
        self.assertTrue(
            mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
        )
        self.assertTrue(
            mx.allclose(
                dequantize_weight(quantized_linear), to_linear_from_quantized.weight
            )
        )

    def test_dora_dtype(self):
        in_dims = 256
        out_dims = 256
        r = 4

        linear = nn.Linear(in_dims, out_dims, bias=True)
        linear.set_dtype(mx.float16)
        dora_lin = DoRALinear.from_base(linear, r)

        x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
        self.assertEqual(dora_lin(x).dtype, mx.float16)

class TestScheduleConfig(unittest.TestCase):
    def test_join(self):
        config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
        cos_with_warmup = build_schedule(config)
        self.assertIsNotNone(cos_with_warmup)

        self.assertEqual(cos_with_warmup(0), 0.0)
        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
        for _ in range(100):
            optimizer.update({}, {})
        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)

    def test_single_schedule(self):

        config = {
            "name": "cosine_decay",
            "arguments": [0.1, 10],
        }
        lr_schedule = build_schedule(config)
        lr = lr_schedule(4)
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_non_zero_warmup(self):
        config = {
            "name": "cosine_decay",
            "warmup": 10,
            "warmup_init": 1e-6,
            "arguments": [1e-5, 20],
        }
        lr_schedule = build_schedule(config)
        lr = lr_schedule(0)
        self.assertAlmostEqual(lr, 1e-6, delta=1e-7)

    def test_malformed_config(self):
        config = {"warmup": 100}
        self.assertRaises(KeyError, build_schedule, config)

        config = {"cosine_decay": None}
        self.assertRaises(KeyError, build_schedule, config)

    def test_evaluate_calls(self):
        mock_model = MagicMock()
        mock_dataset = MagicMock()
        mock_tokenizer = MagicMock()
        mock_default_loss = MagicMock()
        mock_iterate_batches = MagicMock()

        mock_iterate_batches.return_value = [
            (MagicMock(), MagicMock()),
            (MagicMock(), MagicMock()),
            (MagicMock(), MagicMock()),
            (MagicMock(), MagicMock()),
            (MagicMock(), MagicMock()),
        ]

        mock_default_loss.side_effect = [
            (MagicMock(return_value=0.5), MagicMock(return_value=100)),
            (MagicMock(return_value=0.3), MagicMock(return_value=200)),
            (MagicMock(return_value=0.2), MagicMock(return_value=150)),
            (MagicMock(return_value=0.4), MagicMock(return_value=180)),
            (MagicMock(return_value=0.6), MagicMock(return_value=120)),
        ]
        with swapped_with_identity(mx.distributed, "all_sum"):
            evaluate(
                model=mock_model,
                dataset=mock_dataset,
                tokenizer=mock_tokenizer,
                batch_size=2,
                num_batches=2,
                max_seq_length=2048,
                loss=mock_default_loss,
                iterate_batches=mock_iterate_batches,
            )

        mock_iterate_batches.assert_called_once_with(
            dataset=mock_dataset,
            tokenizer=mock_tokenizer,
            batch_size=2,
            max_seq_length=2048,
        )
        self.assertEqual(mock_default_loss.call_count, 2)

    def test_evaluate_infinite_batches(self):
        mock_model = MagicMock()
        mock_dataset = MagicMock()
        mock_tokenizer = MagicMock()
        mock_default_loss = MagicMock()
        mock_iterate_batches = MagicMock()

        mock_iterate_batches.return_value = [
            (MagicMock(), MagicMock()),
            (MagicMock(), MagicMock()),
            (MagicMock(), MagicMock()),
        ]

        mock_default_loss.side_effect = [
            (MagicMock(return_value=0.5), MagicMock(return_value=100)),
            (MagicMock(return_value=0.3), MagicMock(return_value=200)),
            (MagicMock(return_value=0.2), MagicMock(return_value=150)),
        ]

        with swapped_with_identity(mx.distributed, "all_sum"):
            evaluate(
                model=mock_model,
                dataset=mock_dataset,
                tokenizer=mock_tokenizer,
                batch_size=2,
                num_batches=-1,
                max_seq_length=2048,
                loss=mock_default_loss,
                iterate_batches=mock_iterate_batches,
            )

        mock_iterate_batches.assert_called_once_with(
            dataset=mock_dataset,
            tokenizer=mock_tokenizer,
            batch_size=2,
            max_seq_length=2048,
        )
        self.assertEqual(mock_default_loss.call_count, 3)


if __name__ == "__main__":
    unittest.main()
@@ -1,91 +0,0 @@
# Copyright © 2024 Apple Inc.

import unittest
from typing import List

from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import (
    GenerationResponse,
    generate,
    load,
    make_sampler,
    stream_generate,
)


class TestGenerate(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

    def test_generate(self):
        # Simple test that generation runs
        text = generate(
            self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
        )

    def test_generate_with_logit_bias(self):
        logit_bias = {0: 2000.0, 1: -20.0}
        text = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            logits_processors=make_logits_processors(logit_bias),
            verbose=False,
        )
        self.assertEqual(text, "!!!!!")

    def test_generate_with_processor(self):
        init_toks = self.tokenizer.encode("hello")

        all_toks = None

        def logits_processor(toks, logits):
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processors=[logits_processor],
        )
        self.assertEqual(len(all_toks), len(init_toks) + 5)

    def test_stream_generate_speculative(self):
        # Use the same model as the draft model; this is not a speed test
        draft_model, _ = load(self.HF_MODEL_PATH)

        results: List[GenerationResponse] = []
        drafted: List[bool] = []

        # make a deterministic sampler
        sampler = make_sampler(temp=0.0)

        for generation_result in stream_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            prompt="hello",
            max_tokens=5,
            draft_model=draft_model,
            num_draft_tokens=2,
            sampler=sampler,
        ):
            drafted.append(generation_result.from_draft)
            results.append(generation_result)

        self.assertEqual(len(results), 5)
        # since num_draft_tokens is 2 and the draft model is the same, the
        # first 2 generations should be drafts, the third should come
        # from the target model, and the last two should be drafts
        self.assertEqual(drafted, [True, True, False, True, True])


if __name__ == "__main__":
    unittest.main()
@@ -1,58 +0,0 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import mlx.core as mx
from mlx_lm.gguf import convert_to_gguf


class TestConvertToGGUFWithoutMocks(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name
        cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
        with open(cls.tokenizer_file_path, "w") as f:
            f.write("{}")

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("mlx.core.save_gguf")
    def test_convert_to_gguf(
        self,
        mock_save_gguf,
        mock_from_pretrained,
    ):
        mock_tokenizer = MagicMock()
        mock_tokenizer.vocab_size = 3
        mock_tokenizer.get_added_vocab.return_value = {}
        mock_tokenizer.get_vocab.return_value = {"<pad>": 0, "hello": 1, "world": 2}
        mock_tokenizer.all_special_tokens = ["<pad>"]
        mock_tokenizer.all_special_ids = [0]
        mock_from_pretrained.return_value = mock_tokenizer

        model_path = Path(self.test_dir)
        weights = {
            "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
        }
        config = {
            "num_attention_heads": 1,
            "num_hidden_layers": 1,
            "hidden_size": 768,
            "intermediate_size": 3072,
            "_name_or_path": "test-llama",
        }
        output_file_path = "/fake/output/path/gguf_model.gguf"

        convert_to_gguf(model_path, weights, config, output_file_path)
        called_args, _ = mock_save_gguf.call_args
        self.assertEqual(called_args[0], output_file_path)


if __name__ == "__main__":
    unittest.main()
@@ -1,986 +0,0 @@
# Copyright © 2024 Apple Inc.
import unittest

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_causal_mask
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache


class TestModels(unittest.TestCase):

    def test_kv_cache(self):
        cache = KVCache()

        k = mx.ones((1, 4, 1, 32), mx.float16)
        v = mx.ones((1, 4, 1, 32), mx.float16)

        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 1)

        k = mx.ones((1, 4, cache.step, 32), mx.float16)
        v = mx.ones((1, 4, cache.step, 32), mx.float16)
        k_up, v_up = cache.update_and_fetch(k, v)

        expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16)
        self.assertTrue(mx.array_equal(k_up, expected))
        self.assertTrue(mx.array_equal(v_up, expected))
        self.assertEqual(cache.offset, cache.step + 1)

    def test_rotating_kv_cache(self):
        b, h, d = 1, 2, 32
        cache = RotatingKVCache(max_size=8, step=4)

        k = mx.random.uniform(shape=(b, h, 2, d))
        v = mx.random.uniform(shape=(b, h, 2, d))

        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 2)

        k = mx.random.uniform(shape=(b, h, 5, d))
        v = mx.random.uniform(shape=(b, h, 5, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
        self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))

        k = mx.random.uniform(shape=(b, h, 4, d))
        v = mx.random.uniform(shape=(b, h, 4, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
        self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))

        idx = 0
        for _ in range(10):
            k = mx.random.uniform(shape=(b, h, 1, d))
            v = mx.random.uniform(shape=(b, h, 1, d))
            k_up, v_up = cache.update_and_fetch(k, v)
            self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
            self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
            idx += 1
            idx %= 8

        # Try with nonzero keep
        cache = RotatingKVCache(max_size=8, step=4, keep=2)

        # Check a large update
        k = mx.random.uniform(shape=(b, h, 20, d))
        v = mx.random.uniform(shape=(b, h, 20, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))

        # A bunch of small updates
        self.assertEqual(cache.offset, 20)
        idx = 2
        for i in range(10):
            k = mx.random.uniform(shape=(b, h, 1, d))
            v = mx.random.uniform(shape=(b, h, 1, d))
            k_up, v_up = cache.update_and_fetch(k, v)
            self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
            self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
            self.assertEqual(cache.offset, 21 + i)
            idx += 1
            if idx >= 8:
                idx = 2

    def test_rotating_kv_cache_chat_mode(self):
        # Test that the rotating kv cache can handle
        # alternating prompt/prefill with generation
        d = 4
        h = 2
        cache = RotatingKVCache(max_size=18, step=4)

        x = mx.random.uniform(shape=(1, h, 8, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 8)
        self.assertEqual(cache.offset, 8)

        x = mx.random.uniform(shape=(1, h, 1, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 9)
        self.assertEqual(cache.offset, 9)
        self.assertTrue(mx.allclose(x, k[..., 8:9, :]))

        x = mx.random.uniform(shape=(1, h, 2, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 11)
        self.assertEqual(cache.offset, 11)
        self.assertTrue(mx.allclose(x, k[..., 9:11, :]))

        x = mx.random.uniform(shape=(1, h, 3, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 14)
        self.assertEqual(cache.offset, 14)
        self.assertTrue(mx.allclose(x, k[..., 11:14, :]))

        x = mx.random.uniform(shape=(1, h, 6, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(cache.offset, 20)
        self.assertTrue(mx.allclose(x, k[..., -6:, :]))

        x = mx.random.uniform(shape=(1, h, 2, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(cache.offset, 22)
        self.assertTrue(mx.allclose(x, k[..., -2:, :]))

    def test_causal_mask_lengths(self):
        mx.random.seed(8)
        B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
        lengths = mx.array([1, 2, 3, 1])
        q = mx.random.uniform(shape=(B, N_q, T_q, D))
        k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
        v = k
        mask = create_causal_mask(T_q, 0, lengths=lengths)

        out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
        k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
        v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
        out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))

    def test_rope(self):
        rope = rope_utils.initialize_rope(32, base=100, traditional=False)
        self.assertTrue(isinstance(rope, nn.RoPE))

        rope = rope_utils.initialize_rope(
            32,
            base=100,
            traditional=False,
            scaling_config={"rope_type": "linear", "factor": 10.0},
        )
        self.assertTrue(isinstance(rope, nn.RoPE))

        rope = rope_utils.initialize_rope(
            32,
            base=100,
            traditional=False,
            scaling_config={"rope_type": "llama3", "factor": 2.0},
        )
        self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))

    def model_test_runner(self, model, model_type, vocab_size, num_layers):

        self.assertEqual(len(model.layers), num_layers)
        self.assertEqual(model.model_type, model_type)

        for t in [mx.float32, mx.float16]:
            model.update(tree_map(lambda p: p.astype(t), model.parameters()))

            inputs = mx.array([[0, 1]])
            outputs = model(inputs)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)

            cache = make_prompt_cache(model)
            outputs = model(inputs, cache=cache)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)

            if model_type not in ("mamba", "plamo2"):
                mask = create_causal_mask(inputs.shape[1], 0).astype(t)
                outputs = model(inputs, mask=mask)
                self.assertEqual(outputs.shape, (1, 2, vocab_size))
                self.assertEqual(outputs.dtype, t)

            outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
            self.assertEqual(outputs.shape, (1, 1, vocab_size))
            self.assertEqual(outputs.dtype, t)

    def test_llama(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi2(self):
        from mlx_lm.models import phi

        args = phi.ModelArgs()
        model = phi.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phixtral(self):
        from mlx_lm.models import phixtral

        args = phixtral.ModelArgs(
            "phixtral", num_vocab=1000, num_layers=4, model_dim=1024
        )
        model = phixtral.Model(args)
        self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)

    def test_phi3(self):
        from mlx_lm.models import phi3

        args = phi3.ModelArgs(
            model_type="phi3",
            hidden_size=3072,
            num_hidden_layers=32,
            intermediate_size=8192,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=32064,
        )
        model = phi3.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma(self):
        from mlx_lm.models import gemma

        args = gemma.ModelArgs(
            model_type="gemma",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            head_dim=128,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_key_value_heads=4,
        )
        model = gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mixtral(self):
        from mlx_lm.models import mixtral

        # Make a baby mixtral, because it will actually do the
        # eval
        args = mixtral.ModelArgs(
            model_type="mixtral",
            vocab_size=100,
            hidden_size=32,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_experts_per_tok=2,
            num_key_value_heads=2,
            num_local_experts=4,
        )
        model = mixtral.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    @unittest.skip("requires ai2-olmo")
    def test_olmo(self):
        from mlx_lm.models import olmo

        args = olmo.ModelArgs(
            model_type="olmo",
            d_model=1024,
            n_layers=4,
            mlp_hidden_size=2048,
            n_heads=2,
            vocab_size=10_000,
            embedding_size=10_000,
        )
        model = olmo.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            args.n_layers,
        )

    def test_qwen2_moe(self):
        from mlx_lm.models import qwen2_moe

        args = qwen2_moe.ModelArgs(
            model_type="qwen2_moe",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_experts_per_tok=4,
            num_experts=16,
            moe_intermediate_size=1024,
            shared_expert_intermediate_size=2048,
        )
        model = qwen2_moe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen2(self):
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = qwen2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen(self):
        from mlx_lm.models import qwen

        args = qwen.ModelArgs(
            model_type="qwen",
        )
        model = qwen.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo(self):
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            model_type="plamo",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo2(self):
        from mlx_lm.models import plamo2

        args = plamo2.ModelArgs(
            model_type="plamo2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        from mlx_lm.models import stablelm

        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10_000,
            hidden_size=1024,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=1.0,
            intermediate_size=2048,
            layer_norm_eps=1e-2,
            rope_theta=10_000,
            use_qkv_bias=False,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

        # StableLM 2
        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10000,
            hidden_size=512,
            num_attention_heads=8,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=0.25,
            intermediate_size=1024,
            layer_norm_eps=1e-5,
            rope_theta=10000,
            use_qkv_bias=True,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_starcoder2(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
        )
        model = starcoder2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_cohere(self):
        from mlx_lm.models import cohere

        args = cohere.ModelArgs(
            model_type="cohere",
        )
        model = cohere.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_dbrx(self):
        from mlx_lm.models import dbrx

        args = dbrx.ModelArgs(
            model_type="dbrx",
            d_model=1024,
            ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
            attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
            n_layers=4,
            n_heads=4,
            vocab_size=10_000,
        )
        model = dbrx.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)

    def test_minicpm(self):
        from mlx_lm.models import minicpm

        args = minicpm.ModelArgs(
            model_type="minicpm",
            hidden_size=1024,
            dim_model_base=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-4,
            vocab_size=10000,
            num_key_value_heads=2,
            scale_depth=1.0,
            scale_emb=1.0,
        )
        model = minicpm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mamba(self):
        from mlx_lm.models import mamba

        args = mamba.ModelArgs(
            model_type="mamba",
            vocab_size=10000,
            use_bias=False,
            use_conv_bias=True,
            conv_kernel=4,
            hidden_size=768,
            num_hidden_layers=24,
            state_size=16,
            intermediate_size=1536,
            time_step_rank=48,
        )
        model = mamba.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt2(self):
        from mlx_lm.models import gpt2

        args = gpt2.ModelArgs(
            model_type="gpt2",
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
            n_positions=1024,
            layer_norm_epsilon=1e-5,
            vocab_size=50256,
        )
        model = gpt2.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_gpt_neox(self):
        from mlx_lm.models import gpt_neox

        args = gpt_neox.ModelArgs(
            model_type="gpt_neox",
            max_position_embeddings=2048,
            hidden_size=6144,
            num_attention_heads=64,
            num_hidden_layers=44,
            layer_norm_eps=1e-5,
            vocab_size=50432,
            rotary_emb_base=10_000,
            rotary_pct=0.25,
        )
        model = gpt_neox.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_openelm(self):
        from mlx_lm.models import openelm

        args = openelm.ModelArgs(
            model_type="openelm",
            ffn_dim_divisor=256,
            ffn_multipliers=[
                0.5,
                0.73,
                0.97,
                1.2,
                1.43,
                1.67,
                1.9,
                2.13,
                2.37,
                2.6,
                2.83,
                3.07,
                3.3,
                3.53,
                3.77,
                4.0,
            ],
            head_dim=64,
            model_dim=1280,
            normalize_qk_projections=True,
            num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
            num_query_heads=[
                12,
                12,
                12,
                12,
                12,
                16,
                16,
                16,
                16,
                16,
                16,
                16,
                20,
                20,
                20,
                20,
            ],
            num_transformer_layers=16,
            vocab_size=32000,
        )

        model = openelm.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            len(args.ffn_multipliers),
        )

    def test_internlm2(self):
        from mlx_lm.models import internlm2

        args = internlm2.ModelArgs(
            model_type="internlm2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10000,
        )
        model = internlm2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_llama3_1(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            max_position_embeddings=128,
            mlp_bias=False,
            num_key_value_heads=2,
            rope_scaling={
                "factor": 8.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek(self):
        from mlx_lm.models import deepseek

        args = deepseek.ModelArgs(
            model_type="deepseek",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=8,
            num_key_value_heads=4,
        )
        model = deepseek.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek_v2(self):
        from mlx_lm.models import deepseek_v2

        args = deepseek_v2.ModelArgs(
            model_type="deepseek_v2",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            kv_lora_rank=4,
            q_lora_rank=4,
            qk_rope_head_dim=32,
            v_head_dim=16,
            qk_nope_head_dim=32,
            rope_scaling={
                "beta_fast": 32,
                "beta_slow": 1,
                "factor": 40,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 4096,
                "type": "yarn",
            },
        )
        model = deepseek_v2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek_v3(self):
        from mlx_lm.models import deepseek_v3

        args = deepseek_v3.ModelArgs(
            model_type="deepseek_v3",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            n_routed_experts=4,
            n_group=2,
            topk_group=1,
            num_experts_per_tok=2,
            n_shared_experts=1,
            kv_lora_rank=4,
            q_lora_rank=4,
            qk_rope_head_dim=32,
            v_head_dim=16,
            qk_nope_head_dim=32,
            rope_scaling={
                "beta_fast": 32,
                "beta_slow": 1,
                "factor": 40,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 4096,
                "type": "yarn",
            },
        )
        model = deepseek_v3.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma2(self):
        from mlx_lm.models import gemma2

        args = gemma2.ModelArgs(
            model_type="gemma2",
            hidden_size=128,
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=2,
            head_dim=32,
            rms_norm_eps=1e-4,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = gemma2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma3_text(self):
        from mlx_lm.models import gemma3_text

        args = gemma3_text.ModelArgs(
            model_type="gemma3_text",
            hidden_size=128,
            num_hidden_layers=12,
            intermediate_size=256,
            num_attention_heads=4,
            head_dim=32,
            rms_norm_eps=1e-4,
            num_key_value_heads=1,
            sliding_window=1024,
            sliding_window_pattern=6,
        )
        model = gemma3_text.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt_bigcode(self):
        from mlx_lm.models import gpt_bigcode

        args = gpt_bigcode.ModelArgs(
            model_type="gpt_bigcode",
            n_embd=128,
            n_layer=128,
            n_inner=256,
            n_head=4,
            n_positions=1000,
            layer_norm_epsilon=1e-5,
            vocab_size=1024,
        )
        model = gpt_bigcode.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_nemotron(self):
        from mlx_lm.models import nemotron

        args = nemotron.ModelArgs(
            model_type="nemotron",
            hidden_size=128,
            hidden_act="gelu",
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=4,
            norm_eps=1e-5,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = nemotron.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi3small(self):
        from mlx_lm.models import phi3small

        args = phi3small.ModelArgs(
            model_type="phi3small",
            hidden_size=128,
            dense_attention_every_n_layers=2,
            ff_intermediate_size=256,
            gegelu_limit=1.0,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            layer_norm_epsilon=1e-4,
            vocab_size=1000,
        )
        model = phi3small.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phimoe(self):
        from mlx_lm.models import phimoe

        args = phimoe.ModelArgs(
            model_type="phimoe",
            vocab_size=320,
            hidden_size=128,
            intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=4,
            rope_scaling={
                "long_factor": [1.0] * 16,
                "long_mscale": 1.243163121016122,
                "original_max_position_embeddings": 4096,
                "short_factor": [1.0] * 16,
                "short_mscale": 1.243163121016122,
                "type": "longrope",
            },
        )
        model = phimoe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_recurrent_gemma(self):
        from mlx_lm.models import recurrent_gemma

        args = recurrent_gemma.ModelArgs(
            model_type="recurrent_gemma",
            hidden_size=128,
            attention_bias=False,
            conv1d_width=3,
            intermediate_size=256,
            logits_soft_cap=1.0,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            attention_window_size=1024,
            vocab_size=1000,
            block_types=["recurrent", "recurrent", "attention"],
        )
        model = recurrent_gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_hunyuan(self):
        from mlx_lm.models import hunyuan

        args = hunyuan.ModelArgs(
            model_type="hunyuan",
            hidden_size=128,
            attention_bias=False,
            intermediate_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            vocab_size=1000,
            moe_topk=2,
            num_experts=2,
            num_shared_expert=1,
            use_mixed_mlp_moe=True,
            use_qk_norm=True,
            rope_scaling={
                "alpha": 1000.0,
                "factor": 1.0,
                "type": "dynamic",
            },
            use_cla=True,
            cla_share_factor=2,
        )
        model = hunyuan.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_olmo2(self):
        from mlx_lm.models import olmo2

        args = olmo2.ModelArgs(
            model_type="olmo2",
            hidden_size=128,
            attention_bias=False,
            intermediate_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            vocab_size=1000,
        )
        model = olmo2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_exaone(self):
        from mlx_lm.models import exaone

        args = exaone.ModelArgs(
            model_type="exaone",
            hidden_size=128,
            num_layers=4,
            intermediate_size=256,
            num_attention_heads=8,
            num_key_value_heads=2,
            vocab_size=1000,
            layer_norm_epsilon=1e-4,
            rope_theta=10000,
        )
        model = exaone.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)

    def test_cohere2(self):
        from mlx_lm.models import cohere2

        args = cohere2.ModelArgs(
            model_type="cohere2",
            hidden_size=4096,
            head_dim=128,
            num_hidden_layers=40,
            sliding_window=4096,
            sliding_window_pattern=4,
        )
        model = cohere2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_internlm3(self):
        from mlx_lm.models import internlm3

        args = internlm3.ModelArgs(
            model_type="internlm3",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = internlm3.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )


if __name__ == "__main__":
    unittest.main()
@@ -1,305 +0,0 @@
# Copyright © 2024 Apple Inc.

import copy
import os
import tempfile
import unittest

import mlx.core as mx
from mlx_lm.models.cache import (
    KVCache,
    MambaCache,
    QuantizedKVCache,
    RotatingKVCache,
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestPromptCache(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def test_save_load(self):
        cache = [KVCache() for _ in range(4)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        self.assertTrue(len(cache), len(loaded_cache))
        for c, lc in zip(cache, loaded_cache):
            self.assertEqual(c.offset, lc.offset)
            self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
            self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))

        # Test with metadata
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
        metadata = {"a": "b", "c": "d"}
        save_prompt_cache(cache_file, cache, metadata)
        _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
        self.assertEqual(metadata, loaded_metadata)

    def test_save_load_rotating_cache(self):
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")

        # Test with rotating cache
        cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)

        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        self.assertTrue(len(cache), len(loaded_cache))
        for c, lc in zip(cache, loaded_cache):
            self.assertEqual(c.offset, lc.offset)
            self.assertEqual(c.keep, lc.keep)
            self.assertEqual(c.max_size, lc.max_size)
            self.assertEqual(c.step, lc.step)
            self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
            self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))

        # Do a couple single token updates to get a rotation
        for _ in range(2):
            for c in cache:
                x = mx.random.uniform(shape=(1, 8, 1, 4))
                c.update_and_fetch(x, x)

        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)

        for c, lc in zip(cache, loaded_cache):
            x = mx.random.uniform(shape=(1, 8, 1, 4))
            k, v = c.update_and_fetch(x, x)
            lk, lv = lc.update_and_fetch(x, x)
            self.assertEqual(c.offset, lc.offset)
            self.assertTrue(mx.array_equal(k, lk))
            self.assertTrue(mx.array_equal(v, lv))

    def test_save_load_mixed_cache(self):
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")

        cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
        for c in cache:
            if isinstance(c, MambaCache):
                c[0] = mx.random.uniform(shape=(4, 4, 4))
                c[1] = mx.random.uniform(shape=(4, 4, 4))
            else:
                x = mx.random.uniform(shape=(4, 4, 7, 4))
                y = mx.random.uniform(shape=(4, 4, 7, 4))
                c.update_and_fetch(x, y)

        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        for c, lc in zip(cache, loaded_cache):
            if isinstance(c, MambaCache):
                self.assertTrue(mx.array_equal(c[0], lc[0]))
                self.assertTrue(mx.array_equal(c[1], lc[1]))
            else:
                x = mx.random.uniform(shape=(4, 4, 1, 4))
                y = mx.random.uniform(shape=(4, 4, 1, 4))
                k, v = c.update_and_fetch(x, y)
                lk, lv = lc.update_and_fetch(x, y)
                self.assertEqual(c.offset, lc.offset)
                self.assertTrue(mx.array_equal(k, lk))
                self.assertTrue(mx.array_equal(v, lv))

    def test_cache_with_generate(self):
        model, tokenizer = load(HF_MODEL_PATH)
        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
        results = list(generate_step(prompt, model, max_tokens=4))
        toks, all_logits = zip(*results)

        prompt_cache = make_prompt_cache(model)
        i = 0
        for tok, logits in generate_step(
            prompt, model, prompt_cache=prompt_cache, max_tokens=2
        ):
            self.assertEqual(tok, toks[i])
            self.assertTrue(mx.allclose(logits, all_logits[i]))
            i += 1

        for tok, logits in generate_step(
            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
        ):
            i += 1
            self.assertEqual(tok, toks[i])
            self.assertTrue(mx.allclose(logits, all_logits[i]))

    def test_trim_cache(self):
        cache = [KVCache() for _ in range(2)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)

        # Trim
        num_trimmed = trim_prompt_cache(cache, 7)
        self.assertEqual(num_trimmed, 7)

        # Trim more tokens than remain
        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 3)

        # Can't trim mamba cache
        cache = [MambaCache() for _ in range(2)]
        for c in cache:
            c.state = mx.zeros((5, 5))
        num_trimmed = trim_prompt_cache(cache, 7)
        self.assertEqual(num_trimmed, 0)

        # All caches have to be trimmable
        cache = [MambaCache(), KVCache()]
        cache[0].state = mx.zeros((5, 5))
        x = mx.random.uniform(shape=(1, 8, 10, 4))
        cache[1].update_and_fetch(x, x)
        num_trimmed = trim_prompt_cache(cache, 1)
        self.assertEqual(num_trimmed, 0)

        cache = [RotatingKVCache(max_size=6) for _ in range(2)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 5, 4))
            c.update_and_fetch(x, x)

        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 4)

        # Can't trim fixed-size KV cache after processing
        # more than max_kv_size tokens
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)

        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 0)

        cache = [QuantizedKVCache() for _ in range(2)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 64))
            c.update_and_fetch(x, x)

        num_trimmed = trim_prompt_cache(cache, 7)
        self.assertEqual(num_trimmed, 7)

        # Trim more tokens than remain
        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 3)

    def test_trim_cache_with_generate(self):
        model, tokenizer = load(HF_MODEL_PATH)
        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

        prompt_cache = make_prompt_cache(model)

        # Generate one token so we process the full prompt
        last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
        last_tok = mx.array([last_tok])

        # Generate two more tokens
        results = zip(
            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
        )
        toks, all_logits = zip(*(r[1] for r in results))

        # To get back to the cache just after processing the prompt,
        # trim by 3 tokens
        trim_prompt_cache(prompt_cache, 3)

        # Generate the same thing again
        results = zip(
            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
        )
        second_toks, second_all_logits = zip(*(r[1] for r in results))
        self.assertEqual(toks, second_toks)
        self.assertTrue(
            all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
        )

    def test_cache_copying(self):
        cache = [KVCache()]

        x = mx.random.uniform(shape=(1, 8, 10, 4))
        cache[0].update_and_fetch(x, x)

        y = mx.random.uniform(shape=(1, 8, 1, 4))
        cache[0].update_and_fetch(y, y)

        old_cache = copy.deepcopy(cache)

        trim_prompt_cache(cache, 1)

        self.assertTrue(old_cache[0].offset, 11)
        self.assertTrue(cache[0].offset, 10)

        z = mx.random.uniform(shape=(1, 8, 1, 4))
        cache[0].update_and_fetch(z, z)

        self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
        self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))

    def test_save_load_quantized_cache(self):
        cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 32))
            c.update_and_fetch(x, x)
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        self.assertTrue(loaded_cache[0].bits == cache[0].bits)
        self.assertTrue(loaded_cache[0].group_size == cache[0].group_size)
        self.assertTrue(len(cache), len(loaded_cache))
        for c, lc in zip(cache, loaded_cache):
            self.assertEqual(c.offset, lc.offset)
            # Loop over quantized tuple
            for i in range(3):
                self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i]))
                self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i]))

        # Test with metadata
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
        metadata = {"a": "b", "c": "d"}
        save_prompt_cache(cache_file, cache, metadata)
        _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
        self.assertEqual(metadata, loaded_metadata)

    def test_cache_to_quantized(self):
        model, tokenizer = load(HF_MODEL_PATH)
        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
        results = zip(range(4), generate_step(prompt, model))
        toks, all_logits = zip(*(r[1] for r in results))

        prompt_cache = make_prompt_cache(model)
        i = 0
        for _, (tok, logits) in zip(
            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
        ):
            self.assertEqual(tok, toks[i])
            self.assertTrue(mx.allclose(logits, all_logits[i]))
            i += 1

        prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache]

        for _, (tok, logits) in zip(
            range(1),
            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
        ):
            i += 1
            self.assertEqual(tok, toks[i])
            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2))


if __name__ == "__main__":
    unittest.main()
@@ -1,98 +0,0 @@
import unittest

import mlx.core as mx
from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p


class TestSampleUtils(unittest.TestCase):
    def test_apply_top_p(self):
        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
        logits = mx.log(probs)

        new_logits = apply_top_p(logits, 0.3)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])

        new_logits = apply_top_p(logits, 0.95)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertTrue(mx.allclose(probs.squeeze(), actual_probs))

        probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
        logits = mx.log(probs)
        new_logits = apply_top_p(logits, 0.4)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])

        new_logits = apply_top_p(logits, 0.6)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertEqual(
            [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
        )

        new_logits = apply_top_p(logits, 0.95)
        actual_probs = mx.softmax(new_logits.squeeze())
        actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
        expected_rounded = [0.0, 0.5, 0.4, 0.1]
        self.assertEqual(actual_rounded, expected_rounded)
        self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0)

        # Batch mode works
        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
        logits = mx.log(probs)
        new_logits = apply_top_p(logits, 0.5)
        actual_probs = mx.softmax(new_logits, axis=-1)
        self.assertEqual(
            actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
        )

    def test_apply_min_p(self):
        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
        logits = mx.log(probs)
        new_logits = apply_min_p(logits, 0.8)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])

        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
        logits = mx.log(probs)
        new_logits = apply_min_p(logits, 0.05)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertTrue(mx.allclose(actual_probs, mx.squeeze(probs)))

        # Batch mode works
        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
        logits = mx.log(probs)
        new_logits = apply_min_p(logits, 0.7)
        actual_probs = mx.softmax(new_logits, axis=-1)
        self.assertEqual(
            actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
        )

    def test_apply_top_k(self):
        probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
        logits = mx.log(probs)

        new_logits = apply_top_k(logits, 1)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])

        probs = mx.array([0.6, 0.0, 0.1, 0.3])[None]
        logits = mx.log(probs)
        new_logits = apply_top_k(logits, 2)
        actual_probs = mx.softmax(new_logits.squeeze())
        self.assertEqual(
            [round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333]
        )

        # Batch mode works
        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
        logits = mx.log(probs)

        new_logits = apply_top_k(logits, 1)
        actual_probs = mx.softmax(new_logits, axis=-1)
        self.assertEqual(
            actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
        )


if __name__ == "__main__":
    unittest.main()
@@ -1,133 +0,0 @@
# Copyright © 2024 Apple Inc.

import http.server
import json
import threading
import unittest

import requests
from mlx_lm.server import APIHandler
from mlx_lm.utils import load


class DummyModelProvider:
    def __init__(self):
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        self.model, self.tokenizer = load(HF_MODEL_PATH)
        self.model_key = (HF_MODEL_PATH, None)

    def load(self, model, adapter=None):
        assert model in ["default_model", "chat_model"]
        return self.model, self.tokenizer


class TestServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model_provider = DummyModelProvider()
        cls.server_address = ("localhost", 0)
        cls.httpd = http.server.HTTPServer(
            cls.server_address,
            lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs),
        )
        cls.port = cls.httpd.server_port
        cls.server_thread = threading.Thread(target=cls.httpd.serve_forever)
        cls.server_thread.daemon = True
        cls.server_thread.start()

    @classmethod
    def tearDownClass(cls):
        cls.httpd.shutdown()
        cls.httpd.server_close()
        cls.server_thread.join()

    def test_handle_completions(self):
        url = f"http://localhost:{self.port}/v1/completions"

        post_data = {
            "model": "default_model",
            "prompt": "Once upon a time",
            "max_tokens": 10,
            "temperature": 0.5,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "repetition_context_size": 20,
            "stop": "stop sequence",
        }

        response = requests.post(url, json=post_data)

        response_body = response.text

        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)

    def test_handle_chat_completions(self):
        url = f"http://localhost:{self.port}/v1/chat/completions"
        chat_post_data = {
            "model": "chat_model",
            "max_tokens": 10,
            "temperature": 0.7,
            "top_p": 0.85,
            "repetition_penalty": 1.2,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ],
        }
        response = requests.post(url, json=chat_post_data)
        response_body = response.text
        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)

    def test_handle_chat_completions_with_content_fragments(self):
        url = f"http://localhost:{self.port}/v1/chat/completions"
        chat_post_data = {
            "model": "chat_model",
            "max_tokens": 10,
            "temperature": 0.7,
            "top_p": 0.85,
            "repetition_penalty": 1.2,
            "messages": [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": "You are a helpful assistant."}
                    ],
                },
                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
            ],
        }
        response = requests.post(url, json=chat_post_data)
        response_body = response.text
        self.assertIn("id", response_body)
        self.assertIn("choices", response_body)

    def test_handle_models(self):
        url = f"http://localhost:{self.port}/v1/models"
        response = requests.get(url)
        self.assertEqual(response.status_code, 200)
        response_body = json.loads(response.text)
        self.assertEqual(response_body["object"], "list")
        self.assertIsInstance(response_body["data"], list)
        self.assertGreater(len(response_body["data"]), 0)
        model = response_body["data"][0]
        self.assertIn("id", model)
        self.assertEqual(model["object"], "model")
        self.assertIn("created", model)

    def test_sequence_overlap(self):
        from mlx_lm.server import sequence_overlap

        self.assertTrue(sequence_overlap([1], [1]))
        self.assertTrue(sequence_overlap([1, 2], [1, 2]))
        self.assertTrue(sequence_overlap([1, 3], [3, 4]))
        self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))

        self.assertFalse(sequence_overlap([1], [2]))
        self.assertFalse(sequence_overlap([1, 2], [3, 4]))
        self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))


if __name__ == "__main__":
    unittest.main()
@@ -1,98 +0,0 @@
# Copyright © 2024 Apple Inc.

import unittest
from pathlib import Path

from huggingface_hub import snapshot_download
from mlx_lm.tokenizer_utils import (
    BPEStreamingDetokenizer,
    NaiveStreamingDetokenizer,
    SPMStreamingDetokenizer,
    load_tokenizer,
)


class TestTokenizers(unittest.TestCase):

    def download_tokenizer(self, repo):
        path = Path(
            snapshot_download(
                repo_id=repo,
                allow_patterns=[
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "special_tokens_map.json",
                    "tokenizer.model",
                ],
            )
        )
        return load_tokenizer(path)

    def check_tokenizer(self, tokenizer):
        def check(tokens):
            expected_text = tokenizer.decode(tokens)
            detokenizer = tokenizer.detokenizer
            detokenizer.reset()
            text = ""
            for e, t in enumerate(tokens):
                detokenizer.add_token(t)
                seg = detokenizer.last_segment
                text += seg
                self.assertEqual(detokenizer.tokens, tokens[: e + 1])
            detokenizer.finalize()
            text += detokenizer.last_segment
            self.assertEqual(text, expected_text)

        tokens = tokenizer.encode("こんにちは!私の名前はAI")
        check(tokens)

        tokens = tokenizer.encode("a ,b")
        check(tokens)

        tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
        check(tokens)

        tokens = tokenizer.encode("3 3")
        check(tokens)

        tokens = tokenizer.encode("import 'package:flutter/material.dart';")
        check(tokens)

        tokens = tokenizer.encode("hello\nworld")
        check(tokens)

    def test_tokenizers(self):
        tokenizer_repos = [
            ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
            ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
            ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
            ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
            ("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer),
        ]
        for tokenizer_repo, expected_detokenizer in tokenizer_repos:
            with self.subTest(tokenizer=tokenizer_repo):
                tokenizer = self.download_tokenizer(tokenizer_repo)
                tokenizer.decode([0, 1, 2])
                self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
                self.check_tokenizer(tokenizer)

        # Try one with a naive detokenizer
        tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
        tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
        self.check_tokenizer(tokenizer)

    def test_special_tokens(self):
        tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
        tokenizer = self.download_tokenizer(tokenizer_repo)

        detokenizer = tokenizer.detokenizer
        detokenizer.reset()
        detokenizer.add_token(tokenizer.eos_token_id)
        detokenizer.finalize()

        self.assertEqual(detokenizer.last_segment, tokenizer.eos_token)


if __name__ == "__main__":
    unittest.main()
@@ -1,85 +0,0 @@
# Copyright © 2024 Apple Inc.

import sys
import unittest
from io import StringIO
from unittest.mock import MagicMock

import mlx.nn as nn
from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.utils import print_trainable_parameters


class TestTunerUtils(unittest.TestCase):
    def setUp(self):
        self.capturedOutput = StringIO()
        sys.stdout = self.capturedOutput

    def tearDown(self):
        sys.stdout = sys.__stdout__

    def test_quantized_print_trainable_parameters(self):
        model = MagicMock()
        quantized_linear = MagicMock(spec=nn.QuantizedLinear)
        quantized_linear.weight = MagicMock(size=1e6)
        quantized_linear.bits = 8
        lora_linear = MagicMock(spec=LoRALinear)
        lora_linear.weight = MagicMock(size=2e6)
        lora_linear.parameters.return_value = [lora_linear.weight]

        linear = MagicMock(spec=nn.Linear)
        linear.weight = MagicMock(size=3e6)
        linear.parameters.return_value = [linear.weight]

        model.leaf_modules.return_value = {
            "quantized_linear": quantized_linear,
            "lora_linear": lora_linear,
            "linear": linear,
        }

        model.trainable_parameters.return_value = {
            "layer1.weight": MagicMock(size=1e6),
            "layer3.weight": MagicMock(size=2e6),
        }
        expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n"
        print_trainable_parameters(model)
        self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits)
        self.capturedOutput.truncate(0)
        self.capturedOutput.seek(0)

        quantized_linear.weight = MagicMock(size=1e6)
        quantized_linear.bits = 4
        expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n"
        print_trainable_parameters(model)
        self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits)
        self.capturedOutput.truncate(0)
        self.capturedOutput.seek(0)

    def test_print_trainable_parameters(self):
        model = MagicMock()
        linear1 = MagicMock(spec=nn.Linear)
        linear1.weight = MagicMock(size=1e6)
        linear1.parameters.return_value = [linear1.weight]
        linear2 = MagicMock(spec=nn.Linear)
        linear2.weight = MagicMock(size=2e6)
        linear2.parameters.return_value = [linear2.weight]
        lora_linear = MagicMock(spec=LoRALinear)
        lora_linear.weight = MagicMock(size=3e6)
        lora_linear.parameters.return_value = [lora_linear.weight]
        model.leaf_modules.return_value = {
            "linear1": linear1,
            "linear2": linear2,
            "lora_linear": lora_linear,
        }

        model.trainable_parameters.return_value = {
            "layer1.weight": MagicMock(size=1e6),
            "layer3.weight": MagicMock(size=2e6),
        }
        expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n"
        print_trainable_parameters(model)
        self.assertEqual(self.capturedOutput.getvalue(), expected_output)


if __name__ == "__main__":
    unittest.main()
@@ -1,94 +0,0 @@
# Copyright © 2024 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
from mlx_lm import utils

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestUtils(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name
        if not os.path.isdir(cls.test_dir):
            os.mkdir(cls.test_dir_fid.name)

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def test_load(self):
        model, _ = utils.load(HF_MODEL_PATH)

        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)

        mx.eval(model_lazy.parameters())

        p1 = model.layers[0].mlp.up_proj.weight
        p2 = model_lazy.layers[0].mlp.up_proj.weight
        self.assertTrue(mx.allclose(p1, p2))

    def test_make_shards(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=2048,
            num_hidden_layers=32,
            intermediate_size=4096,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=30_000,
        )
        model = llama.Model(args)
        weights = tree_flatten(model.parameters())
        gb = sum(p.nbytes for _, p in weights) // 2**30
        shards = utils.make_shards(dict(weights), 1)
        self.assertTrue(gb <= len(shards) <= gb + 1)

    def test_quantize(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        weights, config = utils.quantize_model(model, {}, 64, 4)
        self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights)
        self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights)
        self.assertEqual(config["quantization"]["group_size"], 64)
        self.assertEqual(config["quantization"]["bits"], 4)

    def test_convert(self):
        mlx_path = os.path.join(self.test_dir, "mlx_model")

        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=True)
        model, _ = utils.load(mlx_path)
        self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear))
        self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))

        # Check model weights have right type
        mlx_path = os.path.join(self.test_dir, "mlx_model_bf16")
        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
        model, _ = utils.load(mlx_path)

        self.assertEqual(model.layers[0].mlp.up_proj.weight.dtype, mx.bfloat16)
        self.assertEqual(model.layers[-1].mlp.up_proj.weight.dtype, mx.bfloat16)


if __name__ == "__main__":
    unittest.main()
@@ -1,50 +0,0 @@
import unittest
from pathlib import Path

import mlx.nn as nn
from mlx_lm.models.qwen2 import Model as Qwen2Model
from mlx_lm.utils import get_model_path, load_model

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestLoadModelCustomGetClasses(unittest.TestCase):

    def test_load_model_with_custom_get_classes(self):
        class CustomQwenModel(nn.Module):
            def __init__(self, args):
                super().__init__()
                self.config = args
                self.custom_attribute = "This is a custom model"

            def load_weights(self, weights, **kwargs):
                self.qwenWeights = weights

        class CustomQwenConfig:
            @classmethod
            def from_dict(cls, config):
                instance = cls()
                for k, v in config.items():
                    setattr(instance, k, v)
                return instance

        def custom_get_classes(config):
            return CustomQwenModel, CustomQwenConfig

        model_path = get_model_path(HF_MODEL_PATH)
        model, _ = load_model(model_path, get_model_classes=custom_get_classes)

        self.assertIsInstance(model, CustomQwenModel)
        self.assertTrue(hasattr(model, "custom_attribute"))
        self.assertEqual(model.custom_attribute, "This is a custom model")
        self.assertTrue(hasattr(model, "qwenWeights"))

    def test_load_model_with_default_get_classes(self):
        model_path = get_model_path(HF_MODEL_PATH)
        model, _ = load_model(model_path)

        self.assertIsInstance(model, Qwen2Model)


if __name__ == "__main__":
    unittest.main()