diff --git a/llms/CONTRIBUTING.md b/llms/CONTRIBUTING.md deleted file mode 100644 index d85067cc..00000000 --- a/llms/CONTRIBUTING.md +++ /dev/null @@ -1,47 +0,0 @@ -# Contributing to MLX LM - -Below are some tips to port LLMs available on Hugging Face to MLX. - -Before starting checkout the [general contribution -guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md). - -Next, from this directory, do an editable install: - -```shell -pip install -e . -``` - -Then check if the model has weights in the -[safetensors](https://huggingface.co/docs/safetensors/index) format. If not -[follow instructions](https://huggingface.co/spaces/safetensors/convert) to -convert it. - -After that, add the model file to the -[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models) -directory. You can see other examples there. We recommend starting from a model -that is similar to the model you are porting. - -Make sure the name of the new model file is the same as the `model_type` in the -`config.json`, for example -[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17). - -To determine the model layer names, we suggest either: - -- Refer to the Transformers implementation if you are familiar with the - codebase. -- Load the model weights and check the weight names which will tell you about - the model structure. -- Look at the names of the weights by inspecting `model.safetensors.index.json` - in the Hugging Face repo. - -To add LoRA support edit -[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60) - -Finally, add a test for the new modle type to the [model -tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py). - -From the `llms/` directory, you can run the tests with: - -```shell -python -m unittest discover tests/ -``` diff --git a/llms/MANIFEST.in b/llms/MANIFEST.in deleted file mode 100644 index 05b93159..00000000 --- a/llms/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -include mlx_lm/requirements.txt -recursive-include mlx_lm/ *.py diff --git a/llms/README.md b/llms/README.md index c9283a0d..10751c98 100644 --- a/llms/README.md +++ b/llms/README.md @@ -1,300 +1,6 @@ -# DEPRECATED +# MOVE NOTICE The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm). -The package here will be removed soon. Send new contributions and issues to the MLX LM repo. - -## Generate Text with LLMs and MLX - -The easiest way to get started is to install the `mlx-lm` package: - -**With `pip`**: - -```sh -pip install mlx-lm -``` - -**With `conda`**: - -```sh -conda install -c conda-forge mlx-lm -``` - -The `mlx-lm` package also has: - -- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md) -- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md) -- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md) - -### Quick Start - -To generate text with an LLM use: - -```bash -mlx_lm.generate --prompt "Hi!" -``` - -To chat with an LLM use: - -```bash -mlx_lm.chat -``` - -This will give you a chat REPL that you can use to interact with the LLM. The -chat context is preserved during the lifetime of the REPL. - -Commands in `mlx-lm` typically take command line options which let you specify -the model, sampling parameters, and more. 
Use `-h` to see a list of available -options for a command, e.g.: - -```bash -mlx_lm.generate -h -``` - -### Python API - -You can use `mlx-lm` as a module: - -```python -from mlx_lm import load, generate - -model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") - -prompt = "Write a story about Einstein" - -messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, add_generation_prompt=True -) - -text = generate(model, tokenizer, prompt=prompt, verbose=True) -``` - -To see a description of all the arguments you can do: - -``` ->>> help(generate) -``` - -Check out the [generation -example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) -to see how to use the API in more detail. - -The `mlx-lm` package also comes with functionality to quantize and optionally -upload models to the Hugging Face Hub. - -You can convert models using the Python API: - -```python -from mlx_lm import convert - -repo = "mistralai/Mistral-7B-Instruct-v0.3" -upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit" - -convert(repo, quantize=True, upload_repo=upload_repo) -``` - -This will generate a 4-bit quantized Mistral 7B and upload it to the repo -`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the -converted model in the path `mlx_model` by default. - -To see a description of all the arguments you can do: - -``` ->>> help(convert) -``` - -#### Streaming - -For streaming generation, use the `stream_generate` function. This yields -a generation response object. - -For example, - -```python -from mlx_lm import load, stream_generate - -repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit" -model, tokenizer = load(repo) - -prompt = "Write a story about Einstein" - -messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, add_generation_prompt=True -) - -for response in stream_generate(model, tokenizer, prompt, max_tokens=512): - print(response.text, end="", flush=True) -print() -``` - -#### Sampling - -The `generate` and `stream_generate` functions accept `sampler` and -`logits_processors` keyword arguments. A sampler is any callable which accepts -a possibly batched logits array and returns an array of sampled tokens. The -`logits_processors` must be a list of callables which take the token history -and current logits as input and return the processed logits. The logits -processors are applied in order. - -Some standard sampling functions and logits processors are provided in -`mlx_lm.sample_utils`. - -### Command Line - -You can also use `mlx-lm` from the command line with: - -``` -mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello" -``` - -This will download a Mistral 7B model from the Hugging Face Hub and generate -text using the given prompt. - -For a full list of options run: - -``` -mlx_lm.generate --help -``` - -To quantize a model from the command line run: - -``` -mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q -``` - -For more options run: - -``` -mlx_lm.convert --help -``` - -You can upload new models to Hugging Face by specifying `--upload-repo` to -`convert`. 
For example, to upload a quantized Mistral-7B model to the -[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do: - -``` -mlx_lm.convert \ - --hf-path mistralai/Mistral-7B-Instruct-v0.3 \ - -q \ - --upload-repo mlx-community/my-4bit-mistral -``` - -Models can also be converted and quantized directly in the -[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging -Face Space. - -### Long Prompts and Generations - -`mlx-lm` has some tools to scale efficiently to long prompts and generations: - -- A rotating fixed-size key-value cache. -- Prompt caching - -To use the rotating key-value cache pass the argument `--max-kv-size n` where -`n` can be any integer. Smaller values like `512` will use very little RAM but -result in worse quality. Larger values like `4096` or higher will use more RAM -but have better quality. - -Caching prompts can substantially speedup reusing the same long context with -different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example: - -```bash -cat prompt.txt | mlx_lm.cache_prompt \ - --model mistralai/Mistral-7B-Instruct-v0.3 \ - --prompt - \ - --prompt-cache-file mistral_prompt.safetensors -``` - -Then use the cached prompt with `mlx_lm.generate`: - -``` -mlx_lm.generate \ - --prompt-cache-file mistral_prompt.safetensors \ - --prompt "\nSummarize the above text." -``` - -The cached prompt is treated as a prefix to the supplied prompt. Also notice -when using a cached prompt, the model to use is read from the cache and need -not be supplied explicitly. - -Prompt caching can also be used in the Python API in order to to avoid -recomputing the prompt. This is useful in multi-turn dialogues or across -requests that use the same context. See the -[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py) -for more usage details. - -### Supported Models - -`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to -run is not supported, file an -[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, -submit a pull request. 
- -Here are a few examples of Hugging Face models that work with this example: - -- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) -- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) -- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) -- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat) -- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) -- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) -- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B) -- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b) -- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct) -- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) -- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) -- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) - -Most -[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), -[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending), -[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending), -and -[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending) -style models should work out of the box. - -For some models (such as `Qwen` and `plamo`) the tokenizer requires you to -enable the `trust_remote_code` option. You can do this by passing -`--trust-remote-code` in the command line. If you don't specify the flag -explicitly, you will be prompted to trust remote code in the terminal when -running the model. - -For `Qwen` models you must also specify the `eos_token`. You can do this by -passing `--eos-token "<|endoftext|>"` in the command -line. - -These options can also be set in the Python API. For example: - -```python -model, tokenizer = load( - "qwen/Qwen-7B", - tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True}, -) -``` - -### Large Models - -> [!NOTE] - This requires macOS 15.0 or higher to work. - -Models which are large relative to the total RAM available on the machine can -be slow. `mlx-lm` will attempt to make them faster by wiring the memory -occupied by the model and cache. This requires macOS 15 or higher to -work. - -If you see the following warning message: - -> [WARNING] Generating with a model that requires ... - -then the model will likely be slow on the given machine. If the model fits in -RAM then it can often be sped up by increasing the system wired memory limit. -To increase the limit, set the following `sysctl`: - -```bash -sudo sysctl iogpu.wired_limit_mb=N -``` - -The value `N` should be larger than the size of the model in megabytes but -smaller than the memory size of the machine. +The package has been removed from the MLX Examples repo. Send new contributions +and issues to the MLX LM repo. 
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md deleted file mode 100644 index e863abc4..00000000 --- a/llms/mlx_lm/LORA.md +++ /dev/null @@ -1,392 +0,0 @@ -# Fine-Tuning with LoRA or QLoRA - -You can use use the `mlx-lm` package to fine-tune an LLM with low rank -adaptation (LoRA) for a target task.[^lora] The example also supports quantized -LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families: - -- Mistral -- Llama -- Phi2 -- Mixtral -- Qwen2 -- Gemma -- OLMo -- MiniCPM -- InternLM2 - -## Contents - -- [Run](#Run) - - [Fine-tune](#Fine-tune) - - [Evaluate](#Evaluate) - - [Generate](#Generate) -- [Fuse](#Fuse) -- [Data](#Data) -- [Memory Issues](#Memory-Issues) - -## Run - -The main command is `mlx_lm.lora`. To see a full list of command-line options run: - -```shell -mlx_lm.lora --help -``` - -Note, in the following the `--model` argument can be any compatible Hugging -Face repo or a local path to a converted model. - -You can also specify a YAML config with `-c`/`--config`. For more on the format see the -[example YAML](examples/lora_config.yaml). For example: - -```shell -mlx_lm.lora --config /path/to/config.yaml -``` - -If command-line flags are also used, they will override the corresponding -values in the config. - -### Fine-tune - -To fine-tune a model use: - -```shell -mlx_lm.lora \ - --model \ - --train \ - --data \ - --iters 600 -``` - -To fine-tune the full model weights, add the `--fine-tune-type full` flag. -Currently supported fine-tuning types are `lora` (default), `dora`, and `full`. - -The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl` -when using `--train` and a path to a `test.jsonl` when using `--test`. For more -details on the data format see the section on [Data](#Data). - -For example, to fine-tune a Mistral 7B you can use `--model -mistralai/Mistral-7B-v0.1`. - -If `--model` points to a quantized model, then the training will use QLoRA, -otherwise it will use regular LoRA. - -By default, the adapter config and learned weights are saved in `adapters/`. -You can specify the output location with `--adapter-path`. - -You can resume fine-tuning with an existing adapter with -`--resume-adapter-file `. - -#### Prompt Masking - -The default training computes a loss for every token in the sample. You can -ignore the prompt and compute loss for just the completion by passing -`--mask-prompt`. Note this is only supported for `chat` and `completion` -datasets. For `chat` datasets the final message in the message list is -considered the completion. See the [dataset section](#Data) for more details. - -### Evaluate - -To compute test set perplexity use: - -```shell -mlx_lm.lora \ - --model \ - --adapter-path \ - --data \ - --test -``` - -### Generate - -For generation use `mlx_lm.generate`: - -```shell -mlx_lm.generate \ - --model \ - --adapter-path \ - --prompt "" -``` - -## Fuse - -You can generate a model fused with the low-rank adapters using the -`mlx_lm.fuse` command. This command also allows you to optionally: - -- Upload the fused model to the Hugging Face Hub. -- Export the fused model to GGUF. Note GGUF support is limited to Mistral, - Mixtral, and Llama style models in fp16 precision. - -To see supported options run: - -```shell -mlx_lm.fuse --help -``` - -To generate the fused model run: - -```shell -mlx_lm.fuse --model -``` - -This will by default load the adapters from `adapters/`, and save the fused -model in the path `fused_model/`. All of these are configurable. 
- -To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments -to `mlx_lm.fuse`. The latter is the repo name of the original model, which is -useful for the sake of attribution and model versioning. - -For example, to fuse and upload a model derived from Mistral-7B-v0.1, run: - -```shell -mlx_lm.fuse \ - --model mistralai/Mistral-7B-v0.1 \ - --upload-repo mlx-community/my-lora-mistral-7b \ - --hf-path mistralai/Mistral-7B-v0.1 -``` - -To export a fused model to GGUF, run: - -```shell -mlx_lm.fuse \ - --model mistralai/Mistral-7B-v0.1 \ - --export-gguf -``` - -This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You -can specify the file name with `--gguf-path`. - -## Data - -The LoRA command expects you to provide a dataset with `--data`. The MLX -Examples GitHub repo has an [example of the WikiSQL -data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the -correct format. - -Datasets can be specified in `*.jsonl` files locally or loaded from Hugging -Face. - -### Local Datasets - -For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a -`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data -loader expects a `test.jsonl` in the data directory. - -Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text` -data formats. Here are examples of these formats: - -`chat`: - -```jsonl -{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]} -``` - -`tools`: - -```jsonl -{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]} -``` - -
-View the expanded single data tool format - -```jsonl -{ - "messages": [ - { "role": "user", "content": "What is the weather in San Francisco?" }, - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_id", - "type": "function", - "function": { - "name": "get_current_weather", - "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}" - } - } - ] - } - ], - "tools": [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and country, eg. San Francisco, USA" - }, - "format": { "type": "string", "enum": ["celsius", "fahrenheit"] } - }, - "required": ["location", "format"] - } - } - } - ] -} -``` - - -The format for the `arguments` field in a function varies for different models. -Common formats include JSON strings and dictionaries. The example provided -follows the format used by -[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples) -and [Mistral -AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct). -A dictionary format is used in Hugging Face's [chat -templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example). -Refer to the documentation for the model you are fine-tuning for more details. - -
- -`completions`: - -```jsonl -{"prompt": "What is the capital of France?", "completion": "Paris."} -``` - -For the `completions` data format, a different key can be used for the prompt -and completion by specifying the following in the YAML config: - -```yaml -prompt_feature: "input" -completion_feature: "output" -``` - -Here, `"input"` is the expected key instead of the default `"prompt"`, and -`"output"` is the expected key instead of `"completion"`. - -`text`: - -```jsonl -{"text": "This is an example for the model."} -``` - -Note, the format is automatically determined by the dataset. Note also, keys -in each line not expected by the loader will be ignored. - -> [!NOTE] -> Each example in the datasets must be on a single line. Do not put more than -> one example per line and do not split an example across multiple lines. - -### Hugging Face Datasets - -To use Hugging Face datasets, first install the `datasets` package: - -``` -pip install datasets -``` - -If the Hugging Face dataset is already in a supported format, you can specify -it on the command line. For example, pass `--data mlx-community/wikisql` to -train on the pre-formatted WikiwSQL data. - -Otherwise, provide a mapping of keys in the dataset to the features MLX LM -expects. Use a YAML config to specify the Hugging Face dataset arguments. For -example: - -```yaml -hf_dataset: - name: "billsum" - prompt_feature: "text" - completion_feature: "summary" -``` - -- Use `prompt_feature` and `completion_feature` to specify keys for a - `completions` dataset. Use `text_feature` to specify the key for a `text` - dataset. Use `chat_feature` to specify the key for a chat dataset. - -- To specify the train, valid, or test splits, set the corresponding - `{train,valid,test}_split` argument. - -You can specify a list of Hugging Face datasets with a list of records each -with the same structure as above. For example: - -```yaml -hf_dataset: - - name: "Open-Orca/OpenOrca" - train_split: "train[:90%]" - valid_split: "train[-10%:]" - prompt_feature: "question" - completion_feature: "response" - - name: "trl-lib/ultrafeedback_binarized" - train_split: "train[:90%]" - valid_split: "train[-10%:]" - chat_feature: "chosen" -``` - -- Arguments specified in `config` will be passed as keyword arguments to - [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). - -In general, for the `chat`, `tools` and `completions` formats, Hugging Face -[chat -templates](https://huggingface.co/docs/transformers/main/en/chat_templating) -are used. This applies the model's chat template by default. If the model does -not have a chat template, then Hugging Face will use a default. For example, -the final text in the `chat` example above with Hugging Face's default template -becomes: - -```text -<|im_start|>system -You are a helpful assistant.<|im_end|> -<|im_start|>user -Hello.<|im_end|> -<|im_start|>assistant -How can I assistant you today.<|im_end|> -``` - -If you are unsure of the format to use, the `chat` or `completions` are good to -start with. For custom requirements on the format of the dataset, use the -`text` format to assemble the content yourself. - -## Memory Issues - -Fine-tuning a large model with LoRA requires a machine with a decent amount -of memory. Here are some tips to reduce memory use should you need to do so: - -1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model - with `convert.py` and the `-q` flag. 
See the [Setup](#setup) section for - more details. - -2. Try using a smaller batch size with `--batch-size`. The default is `4` so - setting this to `2` or `1` will reduce memory consumption. This may slow - things down a little, but will also reduce the memory use. - -3. Reduce the number of layers to fine-tune with `--num-layers`. The default - is `16`, so you can try `8` or `4`. This reduces the amount of memory - needed for back propagation. It may also reduce the quality of the - fine-tuned model if you are fine-tuning with a lot of data. - -4. Longer examples require more memory. If it makes sense for your data, one thing - you can do is break your examples into smaller - sequences when making the `{train, valid, test}.jsonl` files. - -5. Gradient checkpointing lets you trade-off memory use (less) for computation - (more) by recomputing instead of storing intermediate values needed by the - backward pass. You can use gradient checkpointing by passing the - `--grad-checkpoint` flag. Gradient checkpointing will be more helpful for - larger batch sizes or sequence lengths with smaller or quantized models. - -For example, for a machine with 32 GB the following should run reasonably fast: - -``` -mlx_lm.lora \ - --model mistralai/Mistral-7B-v0.1 \ - --train \ - --batch-size 1 \ - --num-layers 4 \ - --data wikisql -``` - -The above command on an M1 Max with 32 GB runs at about 250 -tokens-per-second, using the MLX Example -[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) -data set. - -[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA. - -[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) diff --git a/llms/mlx_lm/MANAGE.md b/llms/mlx_lm/MANAGE.md deleted file mode 100644 index 00858a0a..00000000 --- a/llms/mlx_lm/MANAGE.md +++ /dev/null @@ -1,22 +0,0 @@ -# Managing Models - -You can use `mlx-lm` to manage models downloaded locally in your machine. They -are stored in the Hugging Face cache. - -Scan models: - -```shell -mlx_lm.manage --scan -``` - -Specify a `--pattern` to get info on a single or specific set of models: - -```shell -mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit -``` - -To delete a model (or multiple models): - -```shell -mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit -``` diff --git a/llms/mlx_lm/MERGE.md b/llms/mlx_lm/MERGE.md deleted file mode 100644 index 093c7ed6..00000000 --- a/llms/mlx_lm/MERGE.md +++ /dev/null @@ -1,50 +0,0 @@ -# Model Merging - -You can use `mlx-lm` to merge models and upload them to the Hugging -Face hub or save them locally for LoRA fine tuning. - -The main command is `mlx_lm.merge`: - -```shell -mlx_lm.merge --config config.yaml -``` - -The merged model will be saved by default in `mlx_merged_model`. To see a -full list of options run: - -```shell -mlx_lm.merge --help -``` - -Here is an example `config.yaml`: - -```yaml -models: - - OpenPipe/mistral-ft-optimized-1218 - - mlabonne/NeuralHermes-2.5-Mistral-7B -method: slerp -parameters: - t: - - filter: self_attn - value: [0, 0.5, 0.3, 0.7, 1] - - filter: mlp - value: [1, 0.5, 0.7, 0.3, 0] - - value: 0.5 -``` - -The `models` field is a list of Hugging Face repo ids. The first model in the -list is treated as the base model into which the remaining models are merged. - -The `method` field is the merging method. Right now `slerp` is the only -supported method. 
- -The `parameters` are the corresponding parameters for the given `method`. -Each parameter is a list with `filter` determining which layer the parameter -applies to and `value` determining the actual value used. The last item in -the list without a `filter` field is the default. - -If `value` is a list, it specifies the start and end values for the -corresponding segment of blocks. In the example above, the models have 32 -blocks. For blocks 1-8, the layers with `self_attn` in the name will use the -values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16) -will use `np.linspace(0.5, 0.3, 8)`, and so on. diff --git a/llms/mlx_lm/README.md b/llms/mlx_lm/README.md deleted file mode 100644 index fd11a8f2..00000000 --- a/llms/mlx_lm/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# DEPRECATED - -The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm). - -The package here will be removed soon. Send new contributions and issues to the MLX LM repo. - -## Generate Text with MLX and :hugs: Hugging Face - -This an example of large language model text generation that can pull models from -the Hugging Face Hub. - -For more information on this example, see the [README](../README.md) in the -parent directory. - -This package also supports fine tuning with LoRA or QLoRA. For more information -see the [LoRA documentation](LORA.md). diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md deleted file mode 100644 index e544c6fa..00000000 --- a/llms/mlx_lm/SERVER.md +++ /dev/null @@ -1,131 +0,0 @@ -# HTTP Model Server - -You use `mlx-lm` to make an HTTP API for generating text with any supported -model. The HTTP API is intended to be similar to the [OpenAI chat -API](https://platform.openai.com/docs/api-reference). - -> [!NOTE] -> The MLX LM server is not recommended for production as it only implements -> basic security checks. - -Start the server with: - -```shell -mlx_lm.server --model -``` - -For example: - -```shell -mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit -``` - -This will start a text generation server on port `8080` of the `localhost` -using Mistral 7B instruct. The model will be downloaded from the provided -Hugging Face repo if it is not already in the local cache. - -To see a full list of options run: - -```shell -mlx_lm.server --help -``` - -You can make a request to the model by running: - -```shell -curl localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [{"role": "user", "content": "Say this is a test!"}], - "temperature": 0.7 - }' -``` - -### Request Fields - -- `messages`: An array of message objects representing the conversation - history. Each message object should have a role (e.g. user, assistant) and - content (the message text). - -- `role_mapping`: (Optional) A dictionary to customize the role prefixes in - the generated prompt. If not provided, the default mappings are used. - -- `stop`: (Optional) An array of strings or a single string. These are - sequences of tokens on which the generation should stop. - -- `max_tokens`: (Optional) An integer specifying the maximum number of tokens - to generate. Defaults to `100`. - -- `stream`: (Optional) A boolean indicating if the response should be - streamed. If true, responses are sent as they are generated. Defaults to - false. - -- `temperature`: (Optional) A float specifying the sampling temperature. - Defaults to `1.0`. - -- `top_p`: (Optional) A float specifying the nucleus sampling parameter. - Defaults to `1.0`. 
- -- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens. - Defaults to `1.0`. - -- `repetition_context_size`: (Optional) The size of the context window for - applying repetition penalty. Defaults to `20`. - -- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias - values. Defaults to `None`. - -- `logprobs`: (Optional) An integer specifying the number of top tokens and - corresponding log probabilities to return for each output in the generated - sequence. If set, this can be any value between 1 and 10, inclusive. - -- `model`: (Optional) A string path to a local model or Hugging Face repo id. - If the path is local is must be relative to the directory the server was - started in. - -- `adapters`: (Optional) A string path to low-rank adapters. The path must be - relative to the directory the server was started in. - -### Response Fields - -- `id`: A unique identifier for the chat. - -- `system_fingerprint`: A unique identifier for the system. - -- `object`: Any of "chat.completion", "chat.completion.chunk" (for - streaming), or "text.completion". - -- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). - -- `created`: A time-stamp for when the request was processed. - -- `choices`: A list of outputs. Each output is a dictionary containing the fields: - - `index`: The index in the list. - - `logprobs`: A dictionary containing the fields: - - `token_logprobs`: A list of the log probabilities for the generated - tokens. - - `tokens`: A list of the generated token ids. - - `top_logprobs`: A list of lists. Each list contains the `logprobs` - top tokens (if requested) with their corresponding probabilities. - - `finish_reason`: The reason the completion ended. This can be either of - `"stop"` or `"length"`. - - `message`: The text response from the model. - -- `usage`: A dictionary containing the fields: - - `prompt_tokens`: The number of prompt tokens processed. - - `completion_tokens`: The number of tokens generated. - - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. - -### List Models - -Use the `v1/models` endpoint to list available models: - -```shell -curl localhost:8080/v1/models -H "Content-Type: application/json" -``` - -This will return a list of locally available models where each model in the -list contains the following fields: - -- `id`: The Hugging Face repo id. -- `created`: A time-stamp representing the model creation time. diff --git a/llms/mlx_lm/UPLOAD.md b/llms/mlx_lm/UPLOAD.md deleted file mode 100644 index f5de3655..00000000 --- a/llms/mlx_lm/UPLOAD.md +++ /dev/null @@ -1,37 +0,0 @@ -### Packaging for PyPI - -Install `build` and `twine`: - -``` -pip install --user --upgrade build -pip install --user --upgrade twine -``` - -Generate the source distribution and wheel: - -``` -python -m build -``` - -> [!warning] -> Use a test server first - -#### Test Upload - -Upload to test server: - -``` -python -m twine upload --repository testpypi dist/* -``` - -Install from test server and check that it works: - -``` -python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm -``` - -#### Upload - -``` -python -m twine upload dist/* -``` diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py deleted file mode 100644 index 538be927..00000000 --- a/llms/mlx_lm/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -import os - -from ._version import __version__ - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" - -from .utils import convert, generate, load, stream_generate diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py deleted file mode 100644 index 839089b6..00000000 --- a/llms/mlx_lm/_version.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -__version__ = "0.21.6" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py deleted file mode 100644 index fff64f78..00000000 --- a/llms/mlx_lm/cache_prompt.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import argparse -import json -import sys -import time - -import mlx.core as mx - -from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import generate_step, load - -DEFAULT_QUANTIZED_KV_START = 5000 - - -def setup_arg_parser(): - """Set up and return the argument parser.""" - parser = argparse.ArgumentParser( - description="Cache the state of a prompt to be reused with mlx_lm.generate" - ) - parser.add_argument( - "--model", - type=str, - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--adapter-path", - type=str, - help="Optional path for the trained adapter weights and config.", - ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Enable trusting remote code for tokenizer", - ) - parser.add_argument( - "--eos-token", - type=str, - default=None, - help="End of sequence token for tokenizer", - ) - parser.add_argument( - "--ignore-chat-template", - action="store_true", - help="Use the raw prompt without the tokenizer's chat template.", - ) - parser.add_argument( - "--use-default-chat-template", - action="store_true", - help="Use the default chat template", - ) - parser.add_argument( - "--max-kv-size", - type=int, - default=None, - help="Set the maximum key-value cache size", - ) - parser.add_argument( - "--prompt-cache-file", - help="The file to save the prompt cache in", - required=True, - ) - parser.add_argument( - "--prompt", - required=True, - help="Message to be processed by the model ('-' reads from stdin)", - ) - parser.add_argument( - "--kv-bits", - type=int, - help="Number of bits for KV cache quantization. 
" - "Defaults to no quantization.", - default=None, - ) - parser.add_argument( - "--kv-group-size", - type=int, - help="Group size for KV cache quantization.", - default=64, - ) - parser.add_argument( - "--quantized-kv-start", - help="When --kv-bits is set, start quantizing the KV cache " - "from this step onwards.", - type=int, - default=DEFAULT_QUANTIZED_KV_START, - ) - return parser - - -def main(): - parser = setup_arg_parser() - args = parser.parse_args() - - # Building tokenizer_config - tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token - - model, tokenizer = load( - args.model, - adapter_path=args.adapter_path, - tokenizer_config=tokenizer_config, - ) - - args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt - - if args.use_default_chat_template: - if tokenizer.chat_template is None: - tokenizer.chat_template = tokenizer.default_chat_template - - if not args.ignore_chat_template and tokenizer.chat_template is not None: - messages = [{"role": "user", "content": args.prompt}] - prompt = tokenizer.apply_chat_template( - messages, add_generation_prompt=False, continue_final_message=True - ) - - else: - prompt = tokenizer.encode(args.prompt) - - cache = make_prompt_cache(model, args.max_kv_size) - y = mx.array(prompt) - - # Process the prompt - start = time.time() - max_msg_len = 0 - - def callback(processed, total_tokens): - current = time.time() - speed = processed / (current - start) - msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" - nonlocal max_msg_len - max_msg_len = max(max_msg_len, len(msg)) - print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) - - for _ in generate_step( - y, - model, - max_tokens=0, - prompt_cache=cache, - kv_bits=args.kv_bits, - kv_group_size=args.kv_group_size, - quantized_kv_start=args.quantized_kv_start, - prompt_progress_callback=callback, - ): - pass - - print() - print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") - - print("Saving...") - metadata = {} - metadata["model"] = args.model - metadata["chat_template"] = json.dumps(tokenizer.chat_template) - metadata["tokenizer_config"] = json.dumps(tokenizer_config) - save_prompt_cache(args.prompt_cache_file, cache, metadata) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py deleted file mode 100644 index d8e1ccb9..00000000 --- a/llms/mlx_lm/chat.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -import argparse -import json - -import mlx.core as mx - -from .models.cache import make_prompt_cache -from .sample_utils import make_sampler -from .utils import load, stream_generate - -DEFAULT_TEMP = 0.0 -DEFAULT_TOP_P = 1.0 -DEFAULT_SEED = None -DEFAULT_MAX_TOKENS = 256 -DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" - - -def setup_arg_parser(): - """Set up and return the argument parser.""" - parser = argparse.ArgumentParser(description="Chat with an LLM") - parser.add_argument( - "--model", - type=str, - help="The path to the local model directory or Hugging Face repo.", - default=DEFAULT_MODEL, - ) - parser.add_argument( - "--adapter-path", - type=str, - help="Optional path for the trained adapter weights and config.", - ) - parser.add_argument( - "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" - ) - parser.add_argument( - "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" - ) - parser.add_argument( - "--seed", - type=int, - default=DEFAULT_SEED, - help="PRNG seed", - ) - parser.add_argument( - "--max-kv-size", - type=int, - help="Set the maximum key-value cache size", - default=None, - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=DEFAULT_MAX_TOKENS, - help="Maximum number of tokens to generate", - ) - return parser - - -def main(): - parser = setup_arg_parser() - args = parser.parse_args() - - if args.seed is not None: - mx.random.seed(args.seed) - - model, tokenizer = load( - args.model, - adapter_path=args.adapter_path, - tokenizer_config={"trust_remote_code": True}, - ) - - def print_help(): - print("The command list:") - print("- 'q' to exit") - print("- 'r' to reset the chat") - print("- 'h' to display these commands") - - print(f"[INFO] Starting chat session with {args.model}.") - print_help() - prompt_cache = make_prompt_cache(model, args.max_kv_size) - while True: - query = input(">> ") - if query == "q": - break - if query == "r": - prompt_cache = make_prompt_cache(model, args.max_kv_size) - continue - if query == "h": - print_help() - continue - messages = [{"role": "user", "content": query}] - prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - for response in stream_generate( - model, - tokenizer, - prompt, - max_tokens=args.max_tokens, - sampler=make_sampler(args.temp, args.top_p), - prompt_cache=prompt_cache, - ): - print(response.text, flush=True, end="") - print() - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py deleted file mode 100644 index f268913b..00000000 --- a/llms/mlx_lm/convert.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import argparse - -from . import utils -from .utils import convert - -QUANT_RECIPES = [ - "mixed_2_6", - "mixed_3_6", -] - - -def quant_args(arg): - if arg not in QUANT_RECIPES: - raise argparse.ArgumentTypeError( - f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}" - ) - else: - return getattr(utils, arg) - - -def configure_parser() -> argparse.ArgumentParser: - """ - Configures and returns the argument parser for the script. - - Returns: - argparse.ArgumentParser: Configured argument parser. - """ - parser = argparse.ArgumentParser( - description="Convert Hugging Face model to MLX format" - ) - - parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.") - parser.add_argument( - "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model." 
- ) - parser.add_argument( - "-q", "--quantize", help="Generate a quantized model.", action="store_true" - ) - parser.add_argument( - "--q-group-size", help="Group size for quantization.", type=int, default=64 - ) - parser.add_argument( - "--q-bits", help="Bits per weight for quantization.", type=int, default=4 - ) - parser.add_argument( - "--quant-predicate", - help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}", - type=quant_args, - required=False, - ) - parser.add_argument( - "--dtype", - help="Type to save the non-quantized parameters.", - type=str, - choices=["float16", "bfloat16", "float32"], - default="float16", - ) - parser.add_argument( - "--upload-repo", - help="The Hugging Face repo to upload the model to.", - type=str, - default=None, - ) - parser.add_argument( - "-d", - "--dequantize", - help="Dequantize a quantized model.", - action="store_true", - default=False, - ) - return parser - - -def main(): - parser = configure_parser() - args = parser.parse_args() - convert(**vars(args)) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py deleted file mode 100644 index cd6de7ec..00000000 --- a/llms/mlx_lm/evaluate.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright © 2024 Apple Inc. - -""" -Adapted from a PyTorch implementation by David Grangier -""" - -import argparse -import json -import logging -import os -from importlib.metadata import version -from pathlib import Path -from typing import Optional, Union - -import lm_eval -import mlx.core as mx -import mlx.nn as nn -import numpy as np -from lm_eval.api.model import LM -from lm_eval.api.registry import register_model -from tqdm import tqdm - -from .models.cache import make_prompt_cache -from .utils import load, stream_generate - -PAD = 0 - - -def _len_longest_common_prefix(a, b): - l = 0 - for item_a, item_b in zip(a, b): - if item_a != item_b: - break - l += 1 - return l - - -def _rstrip_until(s, untils): - """Limit a string to the first occurrence of any substring in untils.""" - l = len(s) - f = [s.find(u) for u in untils] - f = [l if x < 0 else x for x in f] - return s[: min(f)] - - -def _pad_inputs( - inputs, - maxlen, - genlen=0, - pad_left=False, - pad_multiple=32, - truncate=False, -): - # pad the prompts to the left with at least genlen tokens. 
- actual_maxlen = max(len(p) for p in inputs) + genlen - if actual_maxlen > maxlen: - if not truncate: - raise ValueError("Inputs are too long.") - else: # drop begining - actual_maxlen = maxlen - inputs = [p[max(0, len(p) - maxlen) :] for p in inputs] - if pad_multiple > 0: - maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple - maxlen *= pad_multiple - assert PAD == 0 - lr = np.array((1, 0) if pad_left else (0, 1)) - return np.stack( - [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs], - axis=0, - ) - - -@register_model("mlxlm") -class MLXLM(LM): - def __init__( - self, - path_or_hf_repo: str, - batch_size: int = 16, - max_tokens: Optional[int] = None, - use_chat_template: Optional[bool] = None, - ) -> None: - super().__init__() - self._batch_size = batch_size - self._model, self.tokenizer = load(path_or_hf_repo) - self._max_tokens = max_tokens or self.tokenizer.model_max_length - self.use_chat_template = use_chat_template or ( - self.tokenizer.chat_template is not None - ) - - def _score_fn(self, inputs, tokenize=True, step_size=32): - if tokenize: - inputs = self._tokenize(inputs) - inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) - inputs = mx.array(inputs) - inputs, targets = inputs[..., :-1], inputs[..., 1:] - - cache = make_prompt_cache(self._model) - - mask = targets != PAD - - scores, is_greedy = [], [] - for i in range(0, inputs.shape[1], step_size): - logits = self._model(inputs[:, i : i + step_size], cache=cache) - - log_probs = nn.log_softmax(logits.astype(mx.float32)) - score = mx.take_along_axis( - log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1 - )[..., 0] - ig = mask[:, i : i + step_size] * ( - targets[:, i : i + step_size] == mx.argmax(logits, axis=-1) - ) - - mx.eval(score, ig) - mx.metal.clear_cache() - - is_greedy.append(ig) - scores.append(score) - - scores = mx.concatenate(scores, axis=1) - is_greedy = mx.concatenate(is_greedy, axis=1) - - return scores, mask.sum(axis=-1), is_greedy - - def _loglikelihood(self, texts, score_spans=None, tokenize=True): - # sort by length to get batches with little padding. - sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i])) - sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))] - sorted_spans = None - if score_spans is not None: - sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))] - - results = [] - for i in tqdm(range(0, len(sorted_inputs), self._batch_size)): - batch = sorted_inputs[i : i + self._batch_size] - scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize) - for j in range(len(batch)): - if sorted_spans is None: # full sequence score - mask = mx.arange(scores[j].shape[-1]) < length - score = (scores[j].astype(mx.float32) * mask).sum(axis=-1) - ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1) - else: # subsequence score - start, end = sorted_spans[i + j] - score = scores[j][start:end].astype(mx.float32).sum() - ig = is_greedy[j][start:end].astype(mx.int32).sum() - length = end - start - - results.append((score.item(), ig.item(), length)) - - # reorder the outputs - inv_sort = np.argsort(sorted_indices) - results = [results[inv_sort[i]] for i in range(len(results))] - - return results - - def _tokenize(self, texts): - return [ - tuple( - self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template) - ) - for t in texts - ] - - def loglikelihood(self, requests) -> list[tuple[float, bool]]: - """Compute log-likelihood of generating a continuation from a context. 
- Downstream tasks should attempt to use loglikelihood instead of other - LM calls whenever possible. - :param requests: list[Instance] - A list of Instance objects, with property `args` which returns a tuple (context, continuation). - `context: str` - Context string. Implementations of LM must be able to handle an - empty context string. - `continuation: str` - The continuation over which log likelihood will be calculated. If - there is a word boundary, the space should be in the continuation. - For example, context="hello" continuation=" world" is correct. - :return: list[tuple[float, bool]] - A list of pairs (logprob, isgreedy) - `logprob: float` - The log probability of `continuation`. - `isgreedy`: - Whether `continuation` would be generated by greedy sampling from `context`. - """ - logging.info("Estimating loglikelihood for %d pairs." % len(requests)) - - # tokenize prefix and prefix + completion for all requests. - tokenized = self._tokenize( - [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]] - ) - - # max length (prefix + completion) and longest common prefix per question. - length_stats = {} - for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): - max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8)) - length_stats[prefix] = ( - max(max_completed_l, len(completed)), - min(min_prefix_l, _len_longest_common_prefix(prefix, completed)), - ) - - # truncate requests for completed sequences longer than model context. - shortened = [] - completion_spans = [] - long_completions = 0 - for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): - max_completed_l, prefix_l = length_stats[prefix] - # compute truncation length - truncation = max(0, max_completed_l - self._max_tokens - 1) - prefix_l = prefix_l - truncation - if prefix_l <= 0: - # completion too long, prefix is eliminated for some requests. - long_completions += 1 - truncation = max(0, len(completed) - self._max_tokens - 1) - prefix_l = 1 - # truncate the completed sequence - completed = completed[truncation:] - shortened.append(completed) - # scores do not include initial bos, substract 1 to span bounds - completion_spans.append((prefix_l - 1, len(completed) - 1)) - - if long_completions > 0: - logging.info( - f"Prefix eliminated for {long_completions} requests with " - + "completion longer than context." - ) - - # model scoring, returns num_requests x (logp, is_greedy, length). - results = self._loglikelihood( - shortened, - score_spans=completion_spans, - tokenize=False, - ) - return [(r[0], r[1] == r[2]) for r in results] - - tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name - apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template - - def loglikelihood_rolling(self, requests) -> list[float]: - """Compute full log-likelihood of a string, with no truncation, for perplexity computation - - We will use the full max context length of the model. - - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to - the max context length. - - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations - which may simply concatenate multiple documents together. - - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into - multiple chunks, the last input will still a full-sized context. 
- Example: - Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] - Prefix: EOT - Max context length: 4 - Resulting input/prediction pairs: - INPUT: EOT 0 1 2 - PRED: 0 1 2 3 - INPUT: 3 4 5 6 - PRED: 4 5 6 7 - INPUT: 5 6 7 8 - PRED: 8 9 - Observe that: - 1. Each token is predicted exactly once - 2. For the last pair, we provide the full context, but only score the last two tokens - :param requests: list[Instance] - A list of Instance objects with property `args` which returns a tuple (context,). - string: str - String for which we are computing overall loglikelihood - :return: list[tuple[float]] - A list of tuples (logprob,) - logprob: float - The log probability of `context` conditioned on the EOT token. - """ - logging.info( - "Estimating loglikelihood rolling for %d sequences." % len(requests) - ) - inputs = [req.args[0] for req in requests] - return [t[0] for t in self._loglikelihood(inputs)] - - def generate_until(self, requests) -> list[str]: - """Generate greedily until a stopping sequence - :param requests: list[Instance] - A list of Instance objects with property `args` which returns a tuple (context, until). - context: str - Context string - until: [str] - The string sequences to generate until. These string sequences - may each span across multiple tokens, or may be part of one token. - :return: list[str] - A list of strings continuation - continuation: str - The generated continuation. - """ - logging.info("Generating continuation for %d sequences." % len(requests)) - contexts, options = zip(*[req.args for req in requests]) - # contrary to the doc the second element of the tuple contains - # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} - completions = [] - - for context, opt in tqdm(zip(contexts, options), total=len(contexts)): - until = opt["until"] - context = self.tokenizer.encode( - context, add_special_tokens=not self.use_chat_template - ) - max_tokens = min( - opt.get("max_gen_tokens", self._max_tokens), - self.tokenizer.model_max_length - len(context), - ) - text = "" - for response in stream_generate( - self._model, self.tokenizer, prompt=context, max_tokens=max_tokens - ): - text += response.text - if any(u in text for u in until): - text = _rstrip_until(text, until) - completions.append(text) - break - else: - completions.append(text) - return completions - - -def main(): - parser = argparse.ArgumentParser( - "Evaluate an MLX model using lm-evaluation-harness." - ) - parser.add_argument("--model", help="Model to evaluate", required=True) - parser.add_argument("--tasks", nargs="+", required=True) - parser.add_argument( - "--output-dir", default=".", help="Output directory for result files." - ) - parser.add_argument("--batch-size", type=int, default=16, help="Batch size") - parser.add_argument("--num-shots", type=int, default=0, help="Number of shots") - parser.add_argument( - "--max-tokens", - type=int, - help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", - ) - parser.add_argument( - "--limit", - default=100, - help="Limit the number of examples per task.", - type=int, - ) - parser.add_argument("--seed", type=int, default=123, help="Random seed.") - parser.add_argument( - "--fewshot-as-multiturn", - action="store_true", - help="Whether to provide the fewshot examples as a multiturn " - "conversation or a single user turn.", - default=False, - ) - parser.add_argument( - "--apply-chat-template", - action=argparse.BooleanOptionalAction, - help="Specifies whether to apply a chat template to the prompt. 
If " - "the model has a chat template, this defaults to `True`, " - "otherwise `False`.", - default=None, - ) - args = parser.parse_args() - - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - # Silence tokenizer warnings - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - mx.random.seed(args.seed) - - lm = MLXLM( - args.model, - batch_size=args.batch_size, - max_tokens=args.max_tokens, - use_chat_template=args.apply_chat_template, - ) - results = lm_eval.simple_evaluate( - model=lm, - tasks=args.tasks, - fewshot_as_multiturn=args.fewshot_as_multiturn, - apply_chat_template=lm.use_chat_template, - num_fewshot=args.num_shots, - limit=args.limit, - random_seed=args.seed, - numpy_random_seed=args.seed, - torch_random_seed=args.seed, - fewshot_random_seed=args.seed, - ) - - model_name = args.model.replace("/", "_") - task_names = "_".join(args.tasks) - ver = version("lm_eval") - filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json" - output_path = output_dir / filename - output_path.write_text(json.dumps(results["results"], indent=4)) - print("Results:") - for result in results["results"].values(): - print(json.dumps(result, indent=4)) diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py deleted file mode 100644 index dcd90b67..00000000 --- a/llms/mlx_lm/examples/chat.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright © 2024 Apple Inc. - -""" -An example of a multi-turn chat with prompt caching. -""" - -from mlx_lm import generate, load -from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache - -model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") - -# Make the initial prompt cache for the model -prompt_cache = make_prompt_cache(model) - -# User turn -prompt = "Hi my name is ." -messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - -# Assistant response -response = generate( - model, - tokenizer, - prompt=prompt, - verbose=True, - prompt_cache=prompt_cache, -) - -# User turn -prompt = "What's my name?" -messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - -# Assistant response -response = generate( - model, - tokenizer, - prompt=prompt, - verbose=True, - prompt_cache=prompt_cache, -) - -# Save the prompt cache to disk to reuse it at a later time -save_prompt_cache("mistral_prompt.safetensors", prompt_cache) - -# Load the prompt cache from disk -prompt_cache = load_prompt_cache("mistral_prompt.safetensors") diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py deleted file mode 100644 index 41eaf1da..00000000 --- a/llms/mlx_lm/examples/generate_response.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright © 2024 Apple Inc. - -from mlx_lm import generate, load - -# Specify the checkpoint -checkpoint = "mistralai/Mistral-7B-Instruct-v0.3" - -# Load the corresponding model and tokenizer -model, tokenizer = load(path_or_hf_repo=checkpoint) - -# Specify the prompt and conversation history -prompt = "Why is the sky blue?" 
-conversation = [{"role": "user", "content": prompt}] - -# Transform the prompt into the chat template -prompt = tokenizer.apply_chat_template( - conversation=conversation, add_generation_prompt=True -) - -# Specify the maximum number of tokens -max_tokens = 1_000 - -# Specify if tokens and timing information will be printed -verbose = True - -# Generate a response with the specified settings -response = generate( - model=model, - tokenizer=tokenizer, - prompt=prompt, - max_tokens=max_tokens, - verbose=verbose, -) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml deleted file mode 100644 index 36bc1dff..00000000 --- a/llms/mlx_lm/examples/lora_config.yaml +++ /dev/null @@ -1,89 +0,0 @@ -# The path to the local model directory or Hugging Face repo. -model: "mlx_model" - -# Whether or not to train (boolean) -train: true - -# The fine-tuning method: "lora", "dora", or "full". -fine_tune_type: lora - -# The Optimizer with its possible inputs -optimizer: adamw -# optimizer_config: -# adamw: -# betas: [0.9, 0.98] -# eps: 1e-6 -# weight_decay: 0.05 -# bias_correction: true - -# Directory with {train, valid, test}.jsonl files -data: "/path/to/training/data" - -# The PRNG seed -seed: 0 - -# Number of layers to fine-tune -num_layers: 16 - -# Minibatch size. -batch_size: 4 - -# Iterations to train for. -iters: 1000 - -# Number of validation batches, -1 uses the entire validation set. -val_batches: 25 - -# Adam learning rate. -learning_rate: 1e-5 - -# Number of training steps between loss reporting. -steps_per_report: 10 - -# Number of training steps between validations. -steps_per_eval: 200 - -# Load path to resume training with the given adapter weights. -resume_adapter_file: null - -# Save/load path for the trained adapter weights. -adapter_path: "adapters" - -# Save the model every N iterations. -save_every: 100 - -# Evaluate on the test set after training -test: false - -# Number of test set batches, -1 uses the entire test set. -test_batches: 100 - -# Maximum sequence length. -max_seq_length: 2048 - -# Use gradient checkpointing to reduce memory use. -grad_checkpoint: false - -# LoRA parameters can only be specified in a config file -lora_parameters: - # The layer keys to apply LoRA to. - # These will be applied for the last lora_layers - keys: ["self_attn.q_proj", "self_attn.v_proj"] - rank: 8 - scale: 20.0 - dropout: 0.0 - -# Schedule can only be specified in a config file, uncomment to use. -#lr_schedule: -# name: cosine_decay -# warmup: 100 # 0 for no warmup -# warmup_init: 1e-7 # 0 if not specified -# arguments: [1e-5, 1000, 1e-7] # passed to scheduler - -#hf_dataset: -# name: "billsum" -# train_split: "train[:1000]" -# valid_split: "train[-100:]" -# prompt_feature: "text" -# completion_feature: "summary" - diff --git a/llms/mlx_lm/examples/merge_config.yaml b/llms/mlx_lm/examples/merge_config.yaml deleted file mode 100644 index 98701e55..00000000 --- a/llms/mlx_lm/examples/merge_config.yaml +++ /dev/null @@ -1,11 +0,0 @@ -models: - - OpenPipe/mistral-ft-optimized-1218 - - mlabonne/NeuralHermes-2.5-Mistral-7B -method: slerp -parameters: - t: - - filter: self_attn - value: [0, 0.5, 0.3, 0.7, 1] - - filter: mlp - value: [1, 0.5, 0.7, 0.3, 0] - - value: 0.5 diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py deleted file mode 100644 index 1e4fb445..00000000 --- a/llms/mlx_lm/examples/pipeline_generate.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright © 2024 Apple Inc. 
- -""" -Run with: - -``` -mlx.launch \ - --hostfile /path/to/hosts.txt \ - --backend mpi \ - /path/to/pipeline_generate.py \ - --prompt "hello world" -``` - -Make sure you can run MLX over MPI on two hosts. For more information see the -documentation: - -https://ml-explore.github.io/mlx/build/html/usage/distributed.html). -""" - -import argparse -import json -from pathlib import Path - -import mlx.core as mx -from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten -from mlx_lm import load, stream_generate -from mlx_lm.utils import load_model, load_tokenizer - - -def download(repo: str, allow_patterns: list[str]) -> Path: - return Path( - snapshot_download( - repo, - allow_patterns=allow_patterns, - ) - ) - - -def shard_and_load(repo): - # Get model path with everything but weight safetensors - model_path = download( - args.model, - allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"], - ) - - # Lazy load and shard model to figure out - # which weights we need - model, _ = load_model(model_path, lazy=True, strict=False) - - group = mx.distributed.init(backend="mpi") - rank = group.rank() - model.model.pipeline(group) - - # Figure out which files we need for the local shard - with open(model_path / "model.safetensors.index.json", "r") as fid: - weight_index = json.load(fid)["weight_map"] - - local_files = set() - for k, _ in tree_flatten(model.parameters()): - local_files.add(weight_index[k]) - - # Download weights for local shard - download(args.model, allow_patterns=local_files) - - # Load and shard the model, and load the weights - tokenizer = load_tokenizer(model_path) - model, _ = load_model(model_path, lazy=True, strict=False) - model.model.pipeline(group) - mx.eval(model.parameters()) - - # Synchronize processes before generation to avoid timeout if downloading - # model for the first time. - mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) - return model, tokenizer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="LLM pipelined inference example") - parser.add_argument( - "--model", - default="mlx-community/DeepSeek-R1-3bit", - help="HF repo or path to local model.", - ) - parser.add_argument( - "--prompt", - "-p", - default="Write a quicksort in C++.", - help="Message to be processed by the model ('-' reads from stdin)", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=256, - help="Maximum number of tokens to generate", - ) - args = parser.parse_args() - - group = mx.distributed.init(backend="mpi") - rank = group.rank() - - def rprint(*args, **kwargs): - if rank == 0: - print(*args, **kwargs) - - model, tokenizer = shard_and_load(args.model) - - messages = [{"role": "user", "content": args.prompt}] - prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - - for response in stream_generate( - model, tokenizer, prompt, max_tokens=args.max_tokens - ): - rprint(response.text, end="", flush=True) - - rprint() - rprint("=" * 10) - rprint( - f"Prompt: {response.prompt_tokens} tokens, " - f"{response.prompt_tps:.3f} tokens-per-sec" - ) - rprint( - f"Generation: {response.generation_tokens} tokens, " - f"{response.generation_tps:.3f} tokens-per-sec" - ) - rprint(f"Peak memory: {response.peak_memory:.3f} GB") diff --git a/llms/mlx_lm/examples/tool_use.py b/llms/mlx_lm/examples/tool_use.py deleted file mode 100644 index 624b9e5b..00000000 --- a/llms/mlx_lm/examples/tool_use.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright © 2025 Apple Inc. 
- -import json - -from mlx_lm import generate, load -from mlx_lm.models.cache import make_prompt_cache - -# Specify the checkpoint -checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit" - -# Load the corresponding model and tokenizer -model, tokenizer = load(path_or_hf_repo=checkpoint) - - -# An example tool, make sure to include a docstring and type hints -def multiply(a: float, b: float): - """ - A function that multiplies two numbers - - Args: - a: The first number to multiply - b: The second number to multiply - """ - return a * b - - -tools = {"multiply": multiply} - -# Specify the prompt and conversation history -prompt = "Multiply 12234585 and 48838483920." -messages = [{"role": "user", "content": prompt}] - -prompt = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tools=list(tools.values()) -) - -prompt_cache = make_prompt_cache(model) - -# Generate the initial tool call: -response = generate( - model=model, - tokenizer=tokenizer, - prompt=prompt, - max_tokens=2048, - verbose=True, - prompt_cache=prompt_cache, -) - -# Parse the tool call: -# (Note, the tool call format is model specific) -tool_open = "" -tool_close = "" -start_tool = response.find(tool_open) + len(tool_open) -end_tool = response.find(tool_close) -tool_call = json.loads(response[start_tool:end_tool].strip()) -tool_result = tools[tool_call["name"]](**tool_call["arguments"]) - -# Put the tool result in the prompt -messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}] -prompt = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, -) - -# Generate the final response: -response = generate( - model=model, - tokenizer=tokenizer, - prompt=prompt, - max_tokens=2048, - verbose=True, - prompt_cache=prompt_cache, -) diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py deleted file mode 100644 index b0c46a74..00000000 --- a/llms/mlx_lm/fuse.py +++ /dev/null @@ -1,130 +0,0 @@ -import argparse -import glob -import shutil -from pathlib import Path - -from mlx.utils import tree_flatten, tree_unflatten - -from .gguf import convert_to_gguf -from .tuner.dora import DoRAEmbedding, DoRALinear -from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear -from .tuner.utils import dequantize, load_adapters -from .utils import ( - fetch_from_hub, - get_model_path, - save_config, - save_weights, - upload_to_hub, -) - - -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Fuse fine-tuned adapters into the base model." - ) - parser.add_argument( - "--model", - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--save-path", - default="fused_model", - help="The path to save the fused model.", - ) - parser.add_argument( - "--adapter-path", - type=str, - default="adapters", - help="Path to the trained adapter weights and config.", - ) - parser.add_argument( - "--hf-path", - type=str, - default=None, - help="Path to the original Hugging Face model. Required for upload if --model is a local directory.", - ) - parser.add_argument( - "--upload-repo", - help="The Hugging Face repo to upload the model to.", - type=str, - default=None, - ) - parser.add_argument( - "--de-quantize", - help="Generate a de-quantized model.", - action="store_true", - ) - parser.add_argument( - "--export-gguf", - help="Export model weights in GGUF format.", - action="store_true", - ) - parser.add_argument( - "--gguf-path", - help="Path to save the exported GGUF format model weights. 
Default is ggml-model-f16.gguf.", - default="ggml-model-f16.gguf", - type=str, - ) - return parser.parse_args() - - -def main() -> None: - print("Loading pretrained model") - args = parse_arguments() - - model_path = get_model_path(args.model) - model, config, tokenizer = fetch_from_hub(model_path) - - model.freeze() - model = load_adapters(model, args.adapter_path) - - fused_linears = [ - (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse") - ] - - if fused_linears: - model.update_modules(tree_unflatten(fused_linears)) - - if args.de_quantize: - print("De-quantizing model") - model = dequantize(model) - - weights = dict(tree_flatten(model.parameters())) - - save_path = Path(args.save_path) - - save_weights(save_path, weights) - - py_files = glob.glob(str(model_path / "*.py")) - for file in py_files: - shutil.copy(file, save_path) - - tokenizer.save_pretrained(save_path) - - if args.de_quantize: - config.pop("quantization", None) - - save_config(config, config_path=save_path / "config.json") - - if args.export_gguf: - model_type = config["model_type"] - if model_type not in ["llama", "mixtral", "mistral"]: - raise ValueError( - f"Model type {model_type} not supported for GGUF conversion." - ) - convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path)) - - if args.upload_repo is not None: - hf_path = args.hf_path or ( - args.model if not Path(args.model).exists() else None - ) - if hf_path is None: - raise ValueError( - "Must provide original Hugging Face repo to upload local model." - ) - upload_to_hub(args.save_path, args.upload_repo, hf_path) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py deleted file mode 100644 index 7d58da82..00000000 --- a/llms/mlx_lm/generate.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import argparse -import json -import sys - -import mlx.core as mx - -from .models.cache import QuantizedKVCache, load_prompt_cache -from .sample_utils import make_sampler -from .utils import generate, load - -DEFAULT_PROMPT = "hello" -DEFAULT_MAX_TOKENS = 100 -DEFAULT_TEMP = 0.0 -DEFAULT_TOP_P = 1.0 -DEFAULT_MIN_P = 0.0 -DEFAULT_MIN_TOKENS_TO_KEEP = 1 -DEFAULT_SEED = None -DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" -DEFAULT_QUANTIZED_KV_START = 5000 - - -def str2bool(string): - return string.lower() not in ["false", "f"] - - -def setup_arg_parser(): - """Set up and return the argument parser.""" - parser = argparse.ArgumentParser(description="LLM inference script") - parser.add_argument( - "--model", - type=str, - help=( - "The path to the local model directory or Hugging Face repo. " - f"If no model is specified, then {DEFAULT_MODEL} is used." 
- ), - default=None, - ) - parser.add_argument( - "--adapter-path", - type=str, - help="Optional path for the trained adapter weights and config.", - ) - parser.add_argument( - "--extra-eos-token", - type=str, - default=(), - nargs="+", - help="Add tokens in the list of eos tokens that stop generation.", - ) - parser.add_argument( - "--system-prompt", - default=None, - help="System prompt to be used for the chat template", - ) - parser.add_argument( - "--prompt", - "-p", - default=DEFAULT_PROMPT, - help="Message to be processed by the model ('-' reads from stdin)", - ) - parser.add_argument( - "--prefill-response", - default=None, - help="Prefill response to be used for the chat template", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=DEFAULT_MAX_TOKENS, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" - ) - parser.add_argument( - "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" - ) - parser.add_argument( - "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p" - ) - parser.add_argument( - "--min-tokens-to-keep", - type=int, - default=DEFAULT_MIN_TOKENS_TO_KEEP, - help="Minimum tokens to keep for min-p sampling.", - ) - parser.add_argument( - "--seed", - type=int, - default=DEFAULT_SEED, - help="PRNG seed", - ) - parser.add_argument( - "--ignore-chat-template", - action="store_true", - help="Use the raw prompt without the tokenizer's chat template.", - ) - parser.add_argument( - "--use-default-chat-template", - action="store_true", - help="Use the default chat template", - ) - parser.add_argument( - "--chat-template-config", - help="Additional config for `apply_chat_template`. Should be a dictionary of" - " string keys to values represented as a JSON decodable string.", - default=None, - ) - parser.add_argument( - "--verbose", - type=str2bool, - default=True, - help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", - ) - parser.add_argument( - "--max-kv-size", - type=int, - help="Set the maximum key-value cache size", - default=None, - ) - parser.add_argument( - "--prompt-cache-file", - type=str, - default=None, - help="A file containing saved KV caches to avoid recomputing them", - ) - parser.add_argument( - "--kv-bits", - type=int, - help="Number of bits for KV cache quantization. 
" - "Defaults to no quantization.", - default=None, - ) - parser.add_argument( - "--kv-group-size", - type=int, - help="Group size for KV cache quantization.", - default=64, - ) - parser.add_argument( - "--quantized-kv-start", - help="When --kv-bits is set, start quantizing the KV cache " - "from this step onwards.", - type=int, - default=DEFAULT_QUANTIZED_KV_START, - ) - parser.add_argument( - "--draft-model", - type=str, - help="A model to be used for speculative decoding.", - default=None, - ) - parser.add_argument( - "--num-draft-tokens", - type=int, - help="Number of tokens to draft when using speculative decoding.", - default=3, - ) - return parser - - -def main(): - parser = setup_arg_parser() - args = parser.parse_args() - - if args.seed is not None: - mx.random.seed(args.seed) - - # Load the prompt cache and metadata if a cache file is provided - using_cache = args.prompt_cache_file is not None - if using_cache: - prompt_cache, metadata = load_prompt_cache( - args.prompt_cache_file, - return_metadata=True, - ) - if isinstance(prompt_cache[0], QuantizedKVCache): - if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits: - raise ValueError( - "--kv-bits does not match the kv cache loaded from --prompt-cache-file." - ) - if args.kv_group_size != prompt_cache[0].group_size: - raise ValueError( - "--kv-group-size does not match the kv cache loaded from --prompt-cache-file." - ) - - # Building tokenizer_config - tokenizer_config = ( - {} if not using_cache else json.loads(metadata["tokenizer_config"]) - ) - tokenizer_config["trust_remote_code"] = True - - model_path = args.model - if using_cache: - if model_path is None: - model_path = metadata["model"] - elif model_path != metadata["model"]: - raise ValueError( - f"Providing a different model ({model_path}) than that " - f"used to create the prompt cache ({metadata['model']}) " - "is an error." - ) - model_path = model_path or DEFAULT_MODEL - - model, tokenizer = load( - model_path, - adapter_path=args.adapter_path, - tokenizer_config=tokenizer_config, - ) - for eos_token in args.extra_eos_token: - tokenizer.add_eos_token(eos_token) - - template_kwargs = {} - if args.chat_template_config is not None: - template_kwargs = json.loads(args.chat_template_config) - - if args.use_default_chat_template: - if tokenizer.chat_template is None: - tokenizer.chat_template = tokenizer.default_chat_template - elif using_cache: - tokenizer.chat_template = json.loads(metadata["chat_template"]) - - prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") - prompt = sys.stdin.read() if prompt == "-" else prompt - if not args.ignore_chat_template and tokenizer.chat_template is not None: - if args.system_prompt is not None: - messages = [{"role": "system", "content": args.system_prompt}] - else: - messages = [] - messages.append({"role": "user", "content": prompt}) - - has_prefill = args.prefill_response is not None - if has_prefill: - messages.append({"role": "assistant", "content": args.prefill_response}) - prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - continue_final_message=has_prefill, - add_generation_prompt=not has_prefill, - **template_kwargs, - ) - - # Treat the prompt as a suffix assuming that the prefix is in the - # stored kv cache. 
- if using_cache: - messages[-1]["content"] = "" - test_prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - continue_final_message=has_prefill, - add_generation_prompt=not has_prefill, - ) - prompt = prompt[test_prompt.index("") :] - prompt = tokenizer.encode(prompt, add_special_tokens=False) - else: - prompt = tokenizer.encode(prompt) - - if args.draft_model is not None: - draft_model, draft_tokenizer = load(args.draft_model) - if draft_tokenizer.vocab_size != tokenizer.vocab_size: - raise ValueError("Draft model tokenizer does not match model tokenizer.") - else: - draft_model = None - sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) - response = generate( - model, - tokenizer, - prompt, - max_tokens=args.max_tokens, - verbose=args.verbose, - sampler=sampler, - max_kv_size=args.max_kv_size, - prompt_cache=prompt_cache if using_cache else None, - kv_bits=args.kv_bits, - kv_group_size=args.kv_group_size, - quantized_kv_start=args.quantized_kv_start, - draft_model=draft_model, - num_draft_tokens=args.num_draft_tokens, - ) - if not args.verbose: - print(response) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py deleted file mode 100644 index 241ac35a..00000000 --- a/llms/mlx_lm/gguf.py +++ /dev/null @@ -1,314 +0,0 @@ -import re -from enum import IntEnum -from pathlib import Path -from typing import Iterable, Optional, Set, Tuple, Union - -import mlx.core as mx -from transformers import AutoTokenizer - - -class TokenType(IntEnum): - NORMAL = 1 - UNKNOWN = 2 - CONTROL = 3 - USER_DEFINED = 4 - UNUSED = 5 - BYTE = 6 - - -class GGMLFileType(IntEnum): - GGML_TYPE_F16 = 1 - - -# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455 -class HfVocab: - def __init__( - self, - fname_tokenizer: Path, - fname_added_tokens: Optional[Union[Path, None]] = None, - ) -> None: - self.tokenizer = AutoTokenizer.from_pretrained( - fname_tokenizer, - cache_dir=fname_tokenizer, - local_files_only=True, - ) - self.added_tokens_list = [] - self.added_tokens_dict = dict() - self.added_tokens_ids = set() - for tok, tokidx in sorted( - self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] - ): - if tokidx >= self.tokenizer.vocab_size: - self.added_tokens_list.append(tok) - self.added_tokens_dict[tok] = tokidx - self.added_tokens_ids.add(tokidx) - self.specials = { - tok: self.tokenizer.get_vocab()[tok] - for tok in self.tokenizer.all_special_tokens - } - self.special_ids = set(self.tokenizer.all_special_ids) - self.vocab_size_base = self.tokenizer.vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer - self.fname_added_tokens = fname_added_tokens - - def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]: - reverse_vocab = { - id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() - } - for token_id in range(self.vocab_size_base): - if token_id in self.added_tokens_ids: - continue - token_text = reverse_vocab[token_id] - yield token_text, self.get_token_score(token_id), self.get_token_type( - token_id, token_text, self.special_ids - ) - - def get_token_type( - self, token_id: int, token_text: bytes, special_ids: Set[int] - ) -> TokenType: - if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text): - return TokenType.BYTE - return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL - - def get_token_score(self, token_id: int) -> float: - return -1000.0 - - def added_tokens(self) -> 
Iterable[Tuple[bytes, float, TokenType]]: - for text in self.added_tokens_list: - if text in self.specials: - toktype = self.get_token_type(self.specials[text], "", self.special_ids) - score = self.get_token_score(self.specials[text]) - else: - toktype = TokenType.USER_DEFINED - score = -1000.0 - yield text, score, toktype - - def has_newline_token(self): - return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab - - def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]: - yield from self.hf_tokens() - yield from self.added_tokens() - - def __repr__(self) -> str: - return f"" - - @staticmethod - def load(path: Path) -> "HfVocab": - added_tokens_path = path.parent / "added_tokens.json" - return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None) - - -def translate_weight_names(name): - name = name.replace("model.layers.", "blk.") - # for mixtral gate - name = name.replace("block_sparse_moe.gate", "ffn_gate_inp") - # for mixtral experts ffns - pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight" - replacement = r"ffn_gate.\1.weight" - name = re.sub(pattern, replacement, name) - pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight" - replacement = r"ffn_down.\1.weight" - name = re.sub(pattern, replacement, name) - pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight" - replacement = r"ffn_up.\1.weight" - name = re.sub(pattern, replacement, name) - - name = name.replace("mlp.gate_proj", "ffn_gate") - name = name.replace("mlp.down_proj", "ffn_down") - name = name.replace("mlp.up_proj", "ffn_up") - name = name.replace("self_attn.q_proj", "attn_q") - name = name.replace("self_attn.k_proj", "attn_k") - name = name.replace("self_attn.v_proj", "attn_v") - name = name.replace("self_attn.o_proj", "attn_output") - name = name.replace("input_layernorm", "attn_norm") - name = name.replace("post_attention_layernorm", "ffn_norm") - name = name.replace("model.embed_tokens", "token_embd") - name = name.replace("model.norm", "output_norm") - name = name.replace("lm_head", "output") - return name - - -def permute_weights(weights, n_head, n_head_kv=None): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - reshaped = weights.reshape( - n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] - ) - swapped = reshaped.swapaxes(1, 2) - final_shape = weights.shape - return swapped.reshape(final_shape) - - -def prepare_metadata(config, vocab): - metadata = { - "general.name": "llama", - "llama.context_length": ( - mx.array(config["max_position_embeddings"], dtype=mx.uint32) - if config.get("max_position_embeddings") is not None - else None - ), - "llama.embedding_length": ( - mx.array(config["hidden_size"], dtype=mx.uint32) - if config.get("hidden_size") is not None - else None - ), - "llama.block_count": ( - mx.array(config["num_hidden_layers"], dtype=mx.uint32) - if config.get("num_hidden_layers") is not None - else None - ), - "llama.feed_forward_length": ( - mx.array(config["intermediate_size"], dtype=mx.uint32) - if config.get("intermediate_size") is not None - else None - ), - "llama.rope.dimension_count": ( - mx.array( - config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32 - ) - if config.get("hidden_size") is not None - and config.get("num_attention_heads") is not None - else None - ), - "llama.attention.head_count": ( - mx.array(config["num_attention_heads"], dtype=mx.uint32) - if config.get("num_attention_heads") is not None - else None - ), - "llama.attention.head_count_kv": ( - mx.array( - 
config.get("num_key_value_heads", config["num_attention_heads"]), - dtype=mx.uint32, - ) - if config.get("num_attention_heads") is not None - else None - ), - "llama.expert_count": ( - mx.array(config.get("num_local_experts", None), dtype=mx.uint32) - if config.get("num_local_experts") is not None - else None - ), - "llama.expert_used_count": ( - mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32) - if config.get("num_experts_per_tok") is not None - else None - ), - "llama.attention.layer_norm_rms_epsilon": ( - mx.array(config.get("rms_norm_eps", 1e-05)) - if config.get("rms_norm_eps") is not None - else None - ), - "llama.rope.freq_base": ( - mx.array(config.get("rope_theta", 10000), dtype=mx.float32) - if config.get("rope_theta") is not None - else None - ), - } - - rope_scaling = config.get("rope_scaling") - if rope_scaling is not None and (typ := rope_scaling.get("type")): - rope_factor = rope_scaling.get("factor") - f_rope_scale = rope_factor - if typ == "linear": - rope_scaling_type = "linear" - metadata["llama.rope.scaling.type"] = rope_scaling_type - metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale) - - metadata["general.file_type"] = mx.array( - GGMLFileType.GGML_TYPE_F16.value, - dtype=mx.uint32, - ) - metadata["general.quantization_version"] = mx.array( - GGMLFileType.GGML_TYPE_F16.value, - dtype=mx.uint32, - ) - metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1] - metadata["general.architecture"] = "llama" - metadata["general.alignment"] = mx.array(32, dtype=mx.uint32) - - # add metadata for vocab - metadata["tokenizer.ggml.model"] = "llama" - tokens = [] - scores = [] - toktypes = [] - for text, score, toktype in vocab.all_tokens(): - tokens.append(text) - scores.append(score) - toktypes.append(toktype.value) - assert len(tokens) == vocab.vocab_size - metadata["tokenizer.ggml.tokens"] = tokens - metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32) - metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32) - if vocab.tokenizer.bos_token_id is not None: - metadata["tokenizer.ggml.bos_token_id"] = mx.array( - vocab.tokenizer.bos_token_id, dtype=mx.uint32 - ) - if vocab.tokenizer.eos_token_id is not None: - metadata["tokenizer.ggml.eos_token_id"] = mx.array( - vocab.tokenizer.eos_token_id, dtype=mx.uint32 - ) - if vocab.tokenizer.unk_token_id is not None: - metadata["tokenizer.ggml.unknown_token_id"] = mx.array( - vocab.tokenizer.unk_token_id, dtype=mx.uint32 - ) - - metadata = {k: v for k, v in metadata.items() if v is not None} - return metadata - - -def convert_to_gguf( - model_path: Union[str, Path], - weights: dict, - config: dict, - output_file_path: str, -): - if isinstance(model_path, str): - model_path = Path(model_path) - - quantization = config.get("quantization", None) - if quantization: - raise NotImplementedError( - "Conversion of quantized models is not yet supported." 
- ) - print("Converting to GGUF format") - # https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention - weights = { - k: ( - permute_weights( - v, config["num_attention_heads"], config["num_attention_heads"] - ) - if "self_attn.q_proj.weight" in k - else ( - permute_weights( - v, config["num_attention_heads"], config["num_key_value_heads"] - ) - if "self_attn.k_proj.weight" in k - else v - ) - ) - for k, v in weights.items() - } - - # rename weights for gguf format - weights = {translate_weight_names(k): v for k, v in weights.items()} - - if not (model_path / "tokenizer.json").exists(): - raise ValueError("Tokenizer json not found") - - vocab = HfVocab.load(model_path) - metadata = prepare_metadata(config, vocab) - - weights = { - k: ( - v.astype(mx.float32).astype(mx.float16) - if v.dtype == mx.bfloat16 - else v.astype(mx.float32) if "norm" in k else v - ) - for k, v in weights.items() - } - - output_file_path = output_file_path - mx.save_gguf(output_file_path, weights, metadata) - print(f"Converted GGUF model saved as: {output_file_path}") diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py deleted file mode 100644 index 042b40e2..00000000 --- a/llms/mlx_lm/lora.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import argparse -import math -import os -import re -import types -from pathlib import Path - -import mlx.nn as nn -import mlx.optimizers as optim -import numpy as np -import yaml - -from .tokenizer_utils import TokenizerWrapper -from .tuner.datasets import load_dataset -from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train -from .tuner.utils import ( - build_schedule, - linear_to_lora_layers, - load_adapters, - print_trainable_parameters, -) -from .utils import load, save_config - -yaml_loader = yaml.SafeLoader -yaml_loader.add_implicit_resolver( - "tag:yaml.org,2002:float", - re.compile( - """^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
- |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$""", - re.X, - ), - list("-+0123456789."), -) - -CONFIG_DEFAULTS = { - "model": "mlx_model", - "train": False, - "fine_tune_type": "lora", - "optimizer": "adam", - "optimizer_config": { - "adam": {}, - "adamw": {}, - }, - "data": "data/", - "seed": 0, - "num_layers": 16, - "batch_size": 4, - "iters": 1000, - "val_batches": 25, - "learning_rate": 1e-5, - "steps_per_report": 10, - "steps_per_eval": 200, - "resume_adapter_file": None, - "adapter_path": "adapters", - "save_every": 100, - "test": False, - "test_batches": 500, - "max_seq_length": 2048, - "config": None, - "grad_checkpoint": False, - "lr_schedule": None, - "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, - "mask_prompt": False, -} - - -def build_parser(): - parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") - parser.add_argument( - "--model", - type=str, - help="The path to the local model directory or Hugging Face repo.", - ) - - # Training args - parser.add_argument( - "--train", - action="store_true", - help="Do training", - default=None, - ) - parser.add_argument( - "--data", - type=str, - help=( - "Directory with {train, valid, test}.jsonl files or the name " - "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')" - ), - ) - parser.add_argument( - "--fine-tune-type", - type=str, - choices=["lora", "dora", "full"], - help="Type of fine-tuning to perform: lora, dora, or full.", - ) - parser.add_argument( - "--optimizer", - type=str, - choices=["adam", "adamw"], - default=None, - help="Optimizer to use for training: adam or adamw", - ) - parser.add_argument( - "--mask-prompt", - action="store_true", - help="Mask the prompt in the loss when training", - default=None, - ) - parser.add_argument( - "--num-layers", - type=int, - help="Number of layers to fine-tune. 
Default is 16, use -1 for all.", - ) - parser.add_argument("--batch-size", type=int, help="Minibatch size.") - parser.add_argument("--iters", type=int, help="Iterations to train for.") - parser.add_argument( - "--val-batches", - type=int, - help="Number of validation batches, -1 uses the entire validation set.", - ) - parser.add_argument("--learning-rate", type=float, help="Adam learning rate.") - parser.add_argument( - "--steps-per-report", - type=int, - help="Number of training steps between loss reporting.", - ) - parser.add_argument( - "--steps-per-eval", - type=int, - help="Number of training steps between validations.", - ) - parser.add_argument( - "--resume-adapter-file", - type=str, - help="Load path to resume training from the given fine-tuned weights.", - ) - parser.add_argument( - "--adapter-path", - type=str, - help="Save/load path for the fine-tuned weights.", - ) - parser.add_argument( - "--save-every", - type=int, - help="Save the model every N iterations.", - ) - parser.add_argument( - "--test", - action="store_true", - help="Evaluate on the test set after training", - default=None, - ) - parser.add_argument( - "--test-batches", - type=int, - help="Number of test set batches, -1 uses the entire test set.", - ) - parser.add_argument( - "--max-seq-length", - type=int, - help="Maximum sequence length.", - ) - parser.add_argument( - "-c", - "--config", - type=str, - help="A YAML configuration file with the training options", - ) - parser.add_argument( - "--grad-checkpoint", - action="store_true", - help="Use gradient checkpointing to reduce memory use.", - default=None, - ) - parser.add_argument("--seed", type=int, help="The PRNG seed") - return parser - - -def train_model( - args, - model: nn.Module, - tokenizer: TokenizerWrapper, - train_set, - valid_set, - training_callback: TrainingCallback = None, -): - model.freeze() - if args.num_layers > len(model.layers): - raise ValueError( - f"Requested to train {args.num_layers} layers " - f"but the model only has {len(model.layers)} layers." 
- ) - - if args.fine_tune_type == "full": - for l in model.layers[-max(args.num_layers, 0) :]: - l.unfreeze() - elif args.fine_tune_type in ["lora", "dora"]: - # Convert linear layers to lora/dora layers and unfreeze in the process - linear_to_lora_layers( - model, - args.num_layers, - args.lora_parameters, - use_dora=(args.fine_tune_type == "dora"), - ) - else: - raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}") - - # Resume from weights if provided - if args.resume_adapter_file is not None: - print(f"Loading fine-tuned weights from {args.resume_adapter_file}") - model.load_weights(args.resume_adapter_file, strict=False) - - print_trainable_parameters(model) - - adapter_path = Path(args.adapter_path) - adapter_path.mkdir(parents=True, exist_ok=True) - - adapter_file = adapter_path / "adapters.safetensors" - save_config(vars(args), adapter_path / "adapter_config.json") - - # init training args - training_args = TrainingArgs( - batch_size=args.batch_size, - iters=args.iters, - val_batches=args.val_batches, - steps_per_report=args.steps_per_report, - steps_per_eval=args.steps_per_eval, - steps_per_save=args.save_every, - adapter_file=adapter_file, - max_seq_length=args.max_seq_length, - grad_checkpoint=args.grad_checkpoint, - ) - - model.train() - - # Initialize the selected optimizer - lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate - - optimizer_name = args.optimizer.lower() - optimizer_config = args.optimizer_config.get(optimizer_name, {}) - - if optimizer_name == "adam": - opt_class = optim.Adam - elif optimizer_name == "adamw": - opt_class = optim.AdamW - else: - raise ValueError(f"Unsupported optimizer: {optimizer_name}") - - opt = opt_class(learning_rate=lr, **optimizer_config) - - # Train model - train( - model=model, - tokenizer=tokenizer, - args=training_args, - optimizer=opt, - train_dataset=train_set, - val_dataset=valid_set, - training_callback=training_callback, - ) - - -def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set): - model.eval() - - test_loss = evaluate( - model=model, - dataset=test_set, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_batches=args.test_batches, - max_seq_length=args.max_seq_length, - ) - - test_ppl = math.exp(test_loss) - - print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") - - -def run(args, training_callback: TrainingCallback = None): - np.random.seed(args.seed) - - print("Loading pretrained model") - model, tokenizer = load(args.model) - - print("Loading datasets") - train_set, valid_set, test_set = load_dataset(args, tokenizer) - - if args.test and not args.train: - # Allow testing without LoRA layers by providing empty path - if args.adapter_path != "": - load_adapters(model, args.adapter_path) - - elif args.train: - print("Training") - train_model(args, model, tokenizer, train_set, valid_set, training_callback) - else: - raise ValueError("Must provide at least one of --train or --test") - - if args.test: - print("Testing") - evaluate_model(args, model, tokenizer, test_set) - - -def main(): - os.environ["TOKENIZERS_PARALLELISM"] = "true" - parser = build_parser() - args = parser.parse_args() - config = args.config - args = vars(args) - if config: - print("Loading configuration file", config) - with open(config, "r") as file: - config = yaml.load(file, yaml_loader) - # Prefer parameters from command-line arguments - for k, v in config.items(): - if args.get(k, None) is None: - args[k] = v - - # Update defaults for unspecified parameters - 
for k, v in CONFIG_DEFAULTS.items(): - if args.get(k, None) is None: - args[k] = v - run(types.SimpleNamespace(**args)) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py deleted file mode 100644 index c06de6b3..00000000 --- a/llms/mlx_lm/manage.py +++ /dev/null @@ -1,139 +0,0 @@ -import argparse -from typing import List, Union - -from huggingface_hub import scan_cache_dir - - -def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: - """ - Inspired by: - - stackoverflow.com/a/8356620/593036 - - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data - """ - col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] - row_format = ("{{:{}}} " * len(headers)).format(*col_widths) - lines = [] - lines.append(row_format.format(*headers)) - lines.append(row_format.format(*["-" * w for w in col_widths])) - for row in rows: - lines.append(row_format.format(*row)) - return "\n".join(lines) - - -def ask_for_confirmation(message: str) -> bool: - """Ask user for confirmation with Y/N prompt. - Returns True for Y/yes, False for N/no/empty.""" - y = ("y", "yes", "1") - n = ("n", "no", "0", "") - full_message = f"{message} (y/n) " - while True: - answer = input(full_message).lower() - if answer in y: - return True - if answer in n: - return False - print(f"Invalid input. Must be one of: yes/no/y/n or empty for no") - - -def main(): - parser = argparse.ArgumentParser(description="MLX Model Cache.") - parser.add_argument( - "--scan", - action="store_true", - help="Scan Hugging Face cache for mlx models.", - ) - parser.add_argument( - "--delete", - action="store_true", - help="Delete models matching the given pattern.", - ) - parser.add_argument( - "--pattern", - type=str, - help="Model repos contain the pattern.", - default="mlx", - ) - - args = parser.parse_args() - - if args.scan: - print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".') - hf_cache_info = scan_cache_dir() - print( - tabulate( - rows=[ - [ - repo.repo_id, - repo.repo_type, - "{:>12}".format(repo.size_on_disk_str), - repo.nb_files, - repo.last_accessed_str, - repo.last_modified_str, - str(repo.repo_path), - ] - for repo in sorted( - hf_cache_info.repos, key=lambda repo: repo.repo_path - ) - if args.pattern in repo.repo_id - ], - headers=[ - "REPO ID", - "REPO TYPE", - "SIZE ON DISK", - "NB FILES", - "LAST_ACCESSED", - "LAST_MODIFIED", - "LOCAL PATH", - ], - ) - ) - - if args.delete: - print(f'Deleting models matching pattern "{args.pattern}"') - hf_cache_info = scan_cache_dir() - - repos = [ - repo - for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) - if args.pattern in repo.repo_id - ] - if repos: - print("\nFound the following models:") - print( - tabulate( - rows=[ - [ - repo.repo_id, - repo.size_on_disk_str, # Added size information - str(repo.repo_path), - ] - for repo in repos - ], - headers=[ - "REPO ID", - "SIZE", # Added size header - "LOCAL PATH", - ], - ) - ) - - confirmed = ask_for_confirmation( - "\nAre you sure you want to delete these models?" 
- ) - if confirmed: - for model_info in repos: - print(f"\nDeleting {model_info.repo_id}...") - for revision in sorted( - model_info.revisions, key=lambda revision: revision.commit_hash - ): - strategy = hf_cache_info.delete_revisions(revision.commit_hash) - strategy.execute() - print("\nModel(s) deleted successfully.") - else: - print("\nDeletion cancelled - no changes made.") - else: - print(f'No models found matching pattern "{args.pattern}"') - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/merge.py b/llms/mlx_lm/merge.py deleted file mode 100644 index a009338e..00000000 --- a/llms/mlx_lm/merge.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import argparse -import glob -import shutil -from pathlib import Path -from typing import Optional - -import mlx.core as mx -import mlx.nn as nn -import numpy as np -import yaml -from mlx.utils import tree_flatten, tree_map - -from .utils import ( - fetch_from_hub, - get_model_path, - save_config, - save_weights, - upload_to_hub, -) - - -def configure_parser() -> argparse.ArgumentParser: - """ - Configures and returns the argument parser for the script. - - Returns: - argparse.ArgumentParser: Configured argument parser. - """ - parser = argparse.ArgumentParser(description="Merge multiple models.") - - parser.add_argument("--config", type=str, help="Path to the YAML config.") - parser.add_argument( - "--mlx-path", - type=str, - default="mlx_merged_model", - help="Path to save the MLX model.", - ) - parser.add_argument( - "--upload-repo", - help="The Hugging Face repo to upload the model to.", - type=str, - default=None, - ) - return parser - - -def slerp(t, w1, w2, eps=1e-5): - """ - Spherical linear interpolation - - Args: - t (float): Interpolation weight in [0.0, 1.0] - w1 (mx.array): First input - w2 (mx.array): Second input - eps (float): Constant for numerical stability - Returns: - mx.array: Interpolated result - """ - t = float(t) - if t == 0: - return w1 - elif t == 1: - return w2 - # Normalize - v1 = w1 / mx.linalg.norm(w1) - v2 = w2 / mx.linalg.norm(w2) - # Angle - dot = mx.clip((v1 * v2).sum(), 0.0, 1.0) - theta = mx.arccos(dot) - sin_theta = mx.sin(theta + eps) - s1 = mx.sin(theta * (1 - t)) / sin_theta - s2 = mx.sin(theta * t) / sin_theta - return s1 * w1 + s2 * w2 - - -def merge_models(base_model: nn.Module, model: nn.Module, config: dict): - method = config.get("method", None) - if method != "slerp": - raise ValueError(f"Merge method {method} not supported") - - num_layers = len(model.layers) - - def unpack_values(vals): - if isinstance(vals, (int, float)): - return np.full(num_layers, vals) - bins = len(vals) - 1 - sizes = [num_layers // bins] * bins - sizes[-1] = num_layers - sum(sizes[:-1]) - return np.concatenate( - [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)] - ) - - param_list = config["parameters"]["t"] - params = {} - filter_keys = set() - for pl in param_list[:-1]: - params[pl["filter"]] = unpack_values(pl["value"]) - filter_keys.add(pl["filter"]) - default = unpack_values(param_list[-1]["value"]) - - for e in range(num_layers): - bl = base_model.layers[e] - l = model.layers[e] - base_weights = bl.parameters() - weights = l.parameters() - for k, w1 in base_weights.items(): - w2 = weights[k] - t = params.get(k, default)[e] - base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2) - base_model.update(base_weights) - - -def merge( - config: str, - mlx_path: str = "mlx_model", - upload_repo: Optional[str] = None, -): - with open(config, "r") as fid: 
- merge_conf = yaml.safe_load(fid) - print("[INFO] Loading") - - model_paths = merge_conf.get("models", []) - if len(model_paths) < 2: - raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.") - - # Load all models - base_hf_path = model_paths[0] - base_path = get_model_path(base_hf_path) - base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True) - models = [] - for mp in model_paths[1:]: - model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True) - base_type = base_config["model_type"] - model_type = model_config["model_type"] - if base_type != model_type: - raise ValueError( - f"Can only merge models of the same type," - f" but got {base_type} and {model_type}." - ) - models.append(model) - - # Merge models into base model - for m in models: - merge_models(base_model, m, merge_conf) - - # Save base model - mlx_path = Path(mlx_path) - weights = dict(tree_flatten(base_model.parameters())) - del models, base_model - save_weights(mlx_path, weights, donate_weights=True) - py_files = glob.glob(str(base_path / "*.py")) - for file in py_files: - shutil.copy(file, mlx_path) - - tokenizer.save_pretrained(mlx_path) - - save_config(config, config_path=mlx_path / "config.json") - - if upload_repo is not None: - upload_to_hub(mlx_path, upload_repo, base_hf_path) - - -def main(): - parser = configure_parser() - args = parser.parse_args() - merge(**vars(args)) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/models/__init__.py b/llms/mlx_lm/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py deleted file mode 100644 index 8b40effb..00000000 --- a/llms/mlx_lm/models/base.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -import inspect -from dataclasses import dataclass -from typing import Any, Optional - -import mlx.core as mx -from mlx.utils import tree_map - -from .cache import QuantizedKVCache - - -@dataclass -class BaseModelArgs: - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) - - -def create_causal_mask( - N: int, - offset: int = 0, - window_size: Optional[int] = None, - lengths: Optional[mx.array] = None, -): - rinds = mx.arange(offset + N) - linds = mx.arange(offset, offset + N) if offset else rinds - linds = linds[:, None] - rinds = rinds[None] - mask = linds >= rinds - if window_size is not None: - mask = mask & (linds <= rinds + window_size) - if lengths is not None: - lengths = lengths[:, None, None, None] - mask = mask & (rinds < lengths) - return mask - - -def create_attention_mask(h: mx.array, cache: Optional[Any] = None): - T = h.shape[1] - if T > 1: - window_size = None - offset = 0 - if cache is not None and cache[0] is not None: - c = cache[0] - if hasattr(c, "max_size"): - offset = min(c.max_size, c.offset) - window_size = c.max_size - else: - offset = c.offset - mask = create_causal_mask(T, offset, window_size=window_size) - else: - mask = None - return mask - - -def quantized_scaled_dot_product_attention( - queries: mx.array, - q_keys: tuple[mx.array, mx.array, mx.array], - q_values: tuple[mx.array, mx.array, mx.array], - scale: float, - mask: Optional[mx.array], - group_size: int = 64, - bits: int = 8, -) -> mx.array: - B, n_q_heads, L, D = queries.shape - n_kv_heads = q_keys[0].shape[-3] - n_repeats = n_q_heads // n_kv_heads - - queries *= scale - - if n_repeats > 1: - queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) - q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) - q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) - - scores = mx.quantized_matmul( - queries, *q_keys, transpose=True, group_size=group_size, bits=bits - ) - if mask is not None: - scores += mask - scores = mx.softmax(scores, axis=-1, precise=True) - out = mx.quantized_matmul( - scores, *q_values, transpose=False, group_size=group_size, bits=bits - ) - - if n_repeats > 1: - out = mx.reshape(out, (B, n_q_heads, L, D)) - - return out - - -def scaled_dot_product_attention( - queries, - keys, - values, - cache, - scale: float, - mask: Optional[mx.array], -) -> mx.array: - if isinstance(cache, QuantizedKVCache): - return quantized_scaled_dot_product_attention( - queries, - keys, - values, - scale=scale, - mask=mask, - group_size=cache.group_size, - bits=cache.bits, - ) - else: - return mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=scale, mask=mask - ) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py deleted file mode 100644 index 14026f0c..00000000 --- a/llms/mlx_lm/models/cache.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from typing import Any, Dict, List, Optional - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_flatten, tree_map, tree_unflatten - - -def make_prompt_cache( - model: nn.Module, - max_kv_size: Optional[int] = None, -) -> List[Any]: - """ - Construct the model's cache for use when cgeneration. - - This function will defer the cache construction to the model if it has a - ``make_cache`` method, otherwise it will make a default KV cache. - - Args: - model (nn.Module): The language model. 
- max_kv_size (Optional[int]): If provided and the model does not have a - ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum - size of ``max_kv_size`` - """ - if hasattr(model, "make_cache"): - return model.make_cache() - - num_layers = len(model.layers) - if max_kv_size is not None: - return [ - RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) - ] - else: - return [KVCache() for _ in range(num_layers)] - - -def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}): - """ - Save a pre-computed prompt cache to a file. - - Args: - file_name (str): The ``.safetensors`` file name. - cache (List[Any]): The model state. - metadata (Dict[str, str]): Optional metadata to save along with model - state. - """ - cache_data = [c.state for c in cache] - cache_info = [c.meta_state for c in cache] - cache_data = dict(tree_flatten(cache_data)) - cache_classes = [type(c).__name__ for c in cache] - cache_metadata = [cache_info, metadata, cache_classes] - cache_metadata = dict(tree_flatten(cache_metadata)) - mx.save_safetensors(file_name, cache_data, cache_metadata) - - -def load_prompt_cache(file_name, return_metadata=False): - """ - Load a prompt cache from a file. - - Args: - file_name (str): The ``.safetensors`` file name. - return_metadata (bool): Whether or not to return metadata. - Default: ``False``. - - Returns: - List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and - the metadata if requested. - """ - arrays, cache_metadata = mx.load(file_name, return_metadata=True) - arrays = tree_unflatten(list(arrays.items())) - cache_metadata = tree_unflatten(list(cache_metadata.items())) - info, metadata, classes = cache_metadata - cache = [globals()[c]() for c in classes] - for c, state, meta_state in zip(cache, arrays, info): - c.state = state - c.meta_state = meta_state - if return_metadata: - return cache, metadata - return cache - - -def can_trim_prompt_cache(cache: List[Any]) -> bool: - """ - Check if model's cache can be trimmed. - """ - return all(c.is_trimmable() for c in cache) - - -def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: - """ - Trim the model's cache by the given number of tokens. - - This function will trim the cache if possible (in-place) and return the - number of tokens that were trimmed. - - Args: - cache (List[Any]): The model's cache. - num_tokens (int): The number of tokens to trim. - - Returns: - (int): The number of tokens that were trimmed. 
- """ - if not can_trim_prompt_cache(cache) or len(cache) == 0: - return 0 - return [c.trim(num_tokens) for c in cache][0] - - -class _BaseCache: - @property - def state(self): - return [] - - @state.setter - def state(self, v): - if v is not None and v: - raise ValueError("This cache has no state but a state was set.") - - @property - def meta_state(self): - return "" - - @meta_state.setter - def meta_state(self, v): - if v is not None and v: - raise ValueError("This cache has no meta_state but a meta_state was set.") - - def is_trimmable(self): - return False - - -class QuantizedKVCache(_BaseCache): - def __init__(self, group_size: int = 64, bits: int = 8): - self.keys = None - self.values = None - self.offset = 0 - self.step = 256 - self.group_size = group_size - self.bits = bits - - def update_and_fetch(self, keys, values): - B, n_kv_heads, num_steps, k_head_dim = keys.shape - v_head_dim = values.shape[-1] - prev = self.offset - - if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: - el_per_int = 8 * mx.uint32.size // self.bits - new_steps = (self.step + num_steps - 1) // self.step * self.step - shape = (B, n_kv_heads, new_steps) - - def init_quant(dim): - return ( - mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), - mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), - mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), - ) - - def expand_quant(x): - new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) - return mx.concatenate([x, new_x], axis=-2) - - if self.keys is not None: - if prev % self.step != 0: - self.keys, self.values = tree_map( - lambda x: x[..., :prev, :], (self.keys, self.values) - ) - - self.keys, self.values = tree_map( - expand_quant, (self.keys, self.values) - ) - else: - self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) - - self.offset += num_steps - - keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) - values = mx.quantize(values, group_size=self.group_size, bits=self.bits) - for i in range(len(self.keys)): - self.keys[i][..., prev : self.offset, :] = keys[i] - self.values[i][..., prev : self.offset, :] = values[i] - - return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) - - @property - def state(self): - if self.offset == self.keys[0].shape[2]: - return self.keys, self.values - else: - return tree_map( - lambda x: x[..., : self.offset, :], (self.keys, self.values) - ) - - @state.setter - def state(self, v): - self.keys, self.values = v - - @property - def meta_state(self): - return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) - - @meta_state.setter - def meta_state(self, v): - self.step, self.offset, self.group_size, self.bits = map(int, v) - - def is_trimmable(self): - return True - - def trim(self, n): - n = min(self.offset, n) - self.offset -= n - return n - - -class KVCache(_BaseCache): - def __init__(self): - self.keys = None - self.values = None - self.offset = 0 - self.step = 256 - - def update_and_fetch(self, keys, values): - prev = self.offset - if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: - B, n_kv_heads, _, k_head_dim = keys.shape - v_head_dim = values.shape[3] - n_steps = (self.step + keys.shape[2] - 1) // self.step - k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) - v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - if prev % self.step != 0: - self.keys = 
self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - - self.offset += keys.shape[2] - self.keys[..., prev : self.offset, :] = keys - self.values[..., prev : self.offset, :] = values - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - - @property - def state(self): - if self.offset == self.keys.shape[2]: - return self.keys, self.values - else: - return ( - self.keys[..., : self.offset, :], - self.values[..., : self.offset, :], - ) - - @state.setter - def state(self, v): - self.keys, self.values = v - self.offset = self.keys.shape[2] - - def is_trimmable(self): - return True - - def trim(self, n): - n = min(self.offset, n) - self.offset -= n - return n - - def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: - quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) - quant_cache.offset = self.offset - if self.keys is not None: - quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) - quant_cache.values = mx.quantize( - self.values, group_size=group_size, bits=bits - ) - return quant_cache - - -class RotatingKVCache(_BaseCache): - - def __init__(self, max_size=None, keep=0, step=256): - self.keep = keep - self.keys = None - self.values = None - self.offset = 0 - self.max_size = max_size - self.step = step - self._idx = 0 - - def _trim(self, trim_size, v, append=None): - to_cat = [] - if trim_size > 0: - to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - return mx.concatenate(to_cat, axis=2) - - def _temporal_order(self, v): - """ - Rearrange the cache into temporal order, slicing off the end if unused. 
- """ - if self._idx == v.shape[2]: - return v - elif self._idx < self.offset: - return mx.concatenate( - [ - v[..., : self.keep, :], - v[..., self._idx :, :], - v[..., self.keep : self._idx, :], - ], - axis=2, - ) - else: - return v[..., : self._idx, :] - - def _update_concat(self, keys, values): - if self.keys is None: - self.keys = keys - self.values = values - else: - # Put the keys/values in temporal order to - # preserve context - self.keys = self._temporal_order(self.keys) - self.values = self._temporal_order(self.values) - - # The largest size is self.max_size + S to ensure - # every token gets at least self.max_size context - trim_size = self._idx - self.max_size - self.keys = self._trim(trim_size, self.keys, keys) - self.values = self._trim(trim_size, self.values, values) - self.offset += keys.shape[2] - self._idx = self.keys.shape[2] - return self.keys, self.values - - def _update_in_place(self, keys, values): - # May not have hit the max size yet, so potentially - # keep growing the cache - B, n_kv_heads, S, k_head_dim = keys.shape - prev = self.offset - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - v_head_dim = values.shape[3] - new_size = min(self.step, self.max_size - prev) - k_shape = (B, n_kv_heads, new_size, k_head_dim) - v_shape = (B, n_kv_heads, new_size, v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self.keys = self._trim(trim_size, self.keys) - self.values = self._trim(trim_size, self.values) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + S, :] = keys - self.values[..., self._idx : self._idx + S, :] = values - self.offset += S - self._idx += S - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values - - def update_and_fetch(self, keys, values): - if keys.shape[2] == 1: - return self._update_in_place(keys, values) - return self._update_concat(keys, values) - - @property - def state(self): - if self.offset < self.keys.shape[2]: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - else: - return self.keys, self.values - - @state.setter - def state(self, v): - self.keys, self.values = v - - @property - def meta_state(self): - return tuple( - map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) - ) - - @meta_state.setter - def meta_state(self, v): - self.keep, self.max_size, self.step, self.offset, self._idx = map( - int, - v, - ) - - def is_trimmable(self): - return self.offset < self.max_size - - def trim(self, n): - n = min(self.offset, n) - self.offset -= n - self._idx -= n - return n - - def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: - raise NotImplementedError("RotatingKVCache Quantization NYI") - - -class MambaCache(_BaseCache): - def __init__(self): - self.cache = [None, None] - - def __setitem__(self, idx, value): - self.cache[idx] = value - - def __getitem__(self, idx): - return self.cache[idx] - - @property - def state(self): - return self.cache 
- - @state.setter - def state(self, v): - self.cache = v diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py deleted file mode 100644 index b2d16dd7..00000000 --- a/llms/mlx_lm/models/cohere.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int = 8192 - num_hidden_layers: int = 40 - intermediate_size: int = 22528 - num_attention_heads: int = 64 - num_key_value_heads: int = 64 - rope_theta: float = 8000000.0 - vocab_size: int = 256000 - layer_norm_eps: float = 1e-05 - logit_scale: float = 0.0625 - attention_bias: bool = False - layer_norm_bias: bool = False - use_qk_norm: bool = False - - -class LayerNorm2D(nn.Module): - - def __init__(self, d1, d2, eps): - super().__init__() - self.weight = mx.zeros((d1, d2)) - self.eps = eps - - def __call__(self, x): - return self.weight * mx.fast.layer_norm(x, None, None, self.eps) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // args.num_attention_heads - self.scale = head_dim**-0.5 - - attetion_bias = args.attention_bias - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) - - self.use_qk_norm = args.use_qk_norm - if self.use_qk_norm: - self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps) - self.k_norm = LayerNorm2D( - self.n_kv_heads, head_dim, eps=args.layer_norm_eps - ) - - self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = queries.reshape(B, L, self.n_heads, -1) - keys = keys.reshape(B, L, self.n_kv_heads, -1) - if self.use_qk_norm: - queries = self.q_norm(queries) - keys = self.k_norm(keys) - - queries = queries.transpose(0, 2, 1, 3) - keys = keys.transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - - def __call__(self, x): - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def 
__init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.n_heads = args.num_attention_heads - - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.LayerNorm( - args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - h = self.input_layernorm(x) - attn_h = self.self_attn(h, mask, cache) - ff_h = self.mlp(h) - return attn_h + ff_h + x - - -class CohereModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.LayerNorm( - args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias - ) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = CohereModel(args) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - out = self.model.embed_tokens.as_linear(out) - out = out * self.model.args.logit_scale - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py deleted file mode 100644 index 19bfa6b6..00000000 --- a/llms/mlx_lm/models/cohere2.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .cache import KVCache, RotatingKVCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int = 4096 - head_dim: int = 128 - num_hidden_layers: int = 32 - intermediate_size: int = 14336 - num_attention_heads: int = 32 - num_key_value_heads: int = 8 - rope_theta: float = 50000.0 - vocab_size: int = 256000 - layer_norm_eps: float = 1e-05 - logit_scale: float = 0.0625 - attention_bias: bool = False - layer_norm_bias: bool = False - sliding_window: int = 4096 - sliding_window_pattern: int = 4 - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__() - self.args = args - self.layer_idx = layer_idx - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.head_dim = head_dim = args.head_dim - if (head_dim * n_heads) != dim: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}" - f" and `num_heads`: {n_heads})." 
- ) - self.scale = head_dim**-0.5 - - attetion_bias = args.attention_bias - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) - - self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) - - self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0 - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - # Apply RoPE only if sliding window is enabled - if self.use_sliding_window: - if cache is None: - queries = self.rope(queries) - keys = self.rope(keys) - else: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - - if cache is not None: - keys, values = cache.update_and_fetch(keys, values) - - if self.use_sliding_window and mask is not None: - key_len = keys.shape[-2] - if mask.shape[-1] != key_len: - mask = mask[..., -key_len:] - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - - def __call__(self, x): - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__() - self.hidden_size = args.hidden_size - self.n_heads = args.num_attention_heads - - self.self_attn = Attention(args, layer_idx) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.LayerNorm( - args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - h = self.input_layernorm(x) - attn_h = self.self_attn(h, mask, cache) - ff_h = self.mlp(h) - return attn_h + ff_h + x - - -class CohereModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args, layer_idx=i) - for i in range(args.num_hidden_layers) - ] - self.norm = nn.LayerNorm( - args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias - ) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if cache is None: - cache = [None] * len(self.layers) - - if mask is None: - j = self.args.sliding_window_pattern - mask = create_attention_mask(h, cache[j - 1 : j]) - - for layer, c in 
zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = CohereModel(args) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - out = self.model.embed_tokens.as_linear(out) - out = out * self.model.args.logit_scale - return out - - def make_cache(self): - caches = [] - for i in range(self.args.num_hidden_layers): - if ( - i % self.args.sliding_window_pattern - == self.args.sliding_window_pattern - 1 - ): - caches.append(KVCache()) - else: - caches.append( - RotatingKVCache(max_size=self.args.sliding_window, keep=0) - ) - return caches - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py deleted file mode 100644 index 886b5630..00000000 --- a/llms/mlx_lm/models/dbrx.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - d_model: int - ffn_config: dict - attn_config: dict - n_layers: int - n_heads: int - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_heads = args.n_heads - self.d_model = args.d_model - self.head_dim = args.d_model // args.n_heads - self.num_key_value_heads = args.attn_config["kv_n_heads"] - self.clip_qkv = args.attn_config["clip_qkv"] - self.rope_theta = args.attn_config["rope_theta"] - - self.scale = self.head_dim**-0.5 - - self.Wqkv = nn.Linear( - args.d_model, - (self.num_key_value_heads * 2 + self.num_heads) * self.head_dim, - bias=False, - ) - self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False) - self.rope = nn.RoPE( - self.head_dim, - traditional=False, - base=self.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - - qkv = self.Wqkv(x) - qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv) - splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads] - queries, keys, values = mx.split(qkv, splits, axis=-1) - - B, L, D = x.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 - ) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.out_proj(output) - - -class NormAttnNorm(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.norm_1 = nn.LayerNorm(args.d_model, bias=False) - self.norm_2 = nn.LayerNorm(args.d_model, bias=False) - self.attn = Attention(args) - - def __call__( - 
self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - h = self.attn(self.norm_1(x), mask=mask, cache=cache) - x = h + x - return x, self.norm_2(x) - - -class MLP(nn.Module): - def __init__(self, d_model: int, ffn_dim: int): - super().__init__() - self.v1 = nn.Linear(d_model, ffn_dim, bias=False) - self.w1 = nn.Linear(d_model, ffn_dim, bias=False) - self.w2 = nn.Linear(ffn_dim, d_model, bias=False) - self.act_fn = nn.silu - - def __call__(self, x: mx.array) -> mx.array: - current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class Router(nn.Module): - def __init__(self, d_model: int, num_experts: int): - super().__init__() - self.layer = nn.Linear(d_model, num_experts, bias=False) - - def __call__(self, x: mx.array): - return self.layer(x) - - -class SparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.d_model = args.d_model - self.ffn_dim = args.ffn_config["ffn_hidden_size"] - self.num_experts = args.ffn_config["moe_num_experts"] - self.num_experts_per_tok = args.ffn_config["moe_top_k"] - - self.router = Router(self.d_model, self.num_experts) - self.experts = [ - MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts) - ] - - def __call__(self, x: mx.array) -> mx.array: - ne = self.num_experts_per_tok - orig_shape = x.shape - x = x.reshape(-1, x.shape[-1]) - - gates = self.router(x) - gates = mx.softmax(gates.astype(mx.float32), axis=-1) - - inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]) - scores = mx.take_along_axis(gates, inds, axis=-1) - scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True) - scores = scores.astype(x.dtype) - - if self.training: - inds = np.array(inds) - y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype) - for e, expert in enumerate(self.experts): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) - - y = (y * scores[:, :, None]).sum(axis=1) - else: - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.stack([self.experts[e](xt) for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt) - y = mx.stack(y, axis=0) - - return y.reshape(orig_shape) - - -class DecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.ffn = SparseMoeBlock(args) - self.norm_attn_norm = NormAttnNorm(args) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r, h = self.norm_attn_norm(x, mask, cache) - out = self.ffn(h) + r - return out - - -class DBRX(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.vocab_size = args.vocab_size - self.wte = nn.Embedding(args.vocab_size, args.d_model) - self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)] - self.norm_f = nn.LayerNorm(args.d_model, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.wte(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.blocks) - - for layer, c in zip(self.blocks, cache): - h = layer(h, mask, c) - - return self.norm_f(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.transformer = DBRX(args) - self.lm_head = nn.Linear(args.d_model, 
args.vocab_size, bias=False) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.transformer(inputs, mask, cache) - return self.lm_head(out) - - @property - def layers(self): - return self.transformer.blocks - - def sanitize(self, weights): - # Split experts into sub matrices - num_experts = self.args.ffn_config["moe_num_experts"] - dim = self.args.ffn_config["ffn_hidden_size"] - - pattern = "experts.mlp" - new_weights = {k: v for k, v in weights.items() if pattern not in k} - for k, v in weights.items(): - if pattern in k: - experts = [ - (k.replace(".mlp", f".{e}") + ".weight", sv) - for e, sv in enumerate(mx.split(v, num_experts, axis=0)) - ] - if k.endswith("w2"): - experts = [(s, sv.T) for s, sv in experts] - new_weights.update(experts) - return new_weights diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py deleted file mode 100644 index ffc30c36..00000000 --- a/llms/mlx_lm/models/deepseek.py +++ /dev/null @@ -1,261 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Optional - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "deepseek" - vocab_size: int = 102400 - hidden_size: int = 4096 - intermediate_size: int = 11008 - moe_intermediate_size: int = 1407 - num_hidden_layers: int = 30 - num_attention_heads: int = 32 - num_key_value_heads: int = 32 - n_shared_experts: Optional[int] = None - n_routed_experts: Optional[int] = None - num_experts_per_tok: Optional[int] = None - moe_layer_freq: int = 1 - first_k_dense_replace: int = 0 - max_position_embeddings: int = 2048 - rms_norm_eps: float = 1e-6 - rope_theta: float = 10000.0 - rope_scaling: Optional[Dict] = None - attention_bias: bool = False - - -class DeepseekAttention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.scale = self.head_dim**-0.5 - - attention_bias = getattr(config, "attention_bias", False) - - self.q_proj = nn.Linear( - self.hidden_size, - config.num_attention_heads * self.head_dim, - bias=attention_bias, - ) - self.k_proj = nn.Linear( - self.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=attention_bias, - ) - self.v_proj = nn.Linear( - self.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=attention_bias, - ) - self.o_proj = nn.Linear( - self.hidden_size, - config.num_attention_heads * self.head_dim, - bias=attention_bias, - ) - - rope_scale = 1.0 - if config.rope_scaling and config.rope_scaling["type"] == "linear": - assert isinstance(config.rope_scaling["factor"], float) - rope_scale = 1 / config.rope_scaling["factor"] - self.rope = nn.RoPE( - self.head_dim, - base=config.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, _ = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose( - 0, 2, 1, 3 - ) - keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, 
L, self.num_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class DeepseekMLP(nn.Module): - def __init__( - self, - config: ModelArgs, - hidden_size: Optional[int] = None, - intermediate_size: Optional[int] = None, - ): - super().__init__() - self.config = config - self.hidden_size = hidden_size or config.hidden_size - self.intermediate_size = intermediate_size or config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = nn.silu - - def __call__(self, x: mx.array) -> mx.array: - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class MoEGate(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) - - def __call__(self, x): - gates = x @ self.weight.T - scores = mx.softmax(gates, axis=-1, precise=True) - k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) - scores = mx.take_along_axis(scores, inds, axis=-1) - return inds, scores - - -class DeepseekMoE(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.switch_mlp = SwitchGLU( - config.hidden_size, config.moe_intermediate_size, config.n_routed_experts - ) - - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekMLP( - config=config, intermediate_size=intermediate_size - ) - - def __call__(self, x): - inds, scores = self.gate(x) - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(x) - - return y - - -class DeepseekDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, layer_idx: int): - super().__init__() - self.self_attn = DeepseekAttention(config) - self.mlp = ( - DeepseekMoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekMLP(config) - ) - self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class DeepseekModel(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [ - DeepseekDecoderLayer(config, idx) 
for idx in range(config.num_hidden_layers) - ] - self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - x: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ) -> mx.array: - h = self.embed_tokens(x) - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.args = config - self.model_type = config.model_type - self.model = DeepseekModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ): - out = self.model(inputs, cache, mask) - return self.lm_head(out) - - def sanitize(self, weights): - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for m in ["gate_proj", "down_proj", "up_proj"]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") - for e in range(self.args.n_routed_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py deleted file mode 100644 index 7a5bdeb1..00000000 --- a/llms/mlx_lm/models/deepseek_v2.py +++ /dev/null @@ -1,462 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "deepseek_v2" - vocab_size: int = 102400 - hidden_size: int = 4096 - intermediate_size: int = 11008 - moe_intermediate_size: int = 1407 - num_hidden_layers: int = 30 - num_attention_heads: int = 32 - num_key_value_heads: int = 32 - n_shared_experts: Optional[int] = None - n_routed_experts: Optional[int] = None - routed_scaling_factor: float = 1.0 - kv_lora_rank: int = 512 - q_lora_rank: int = 1536 - qk_rope_head_dim: int = 64 - v_head_dim: int = 128 - qk_nope_head_dim: int = 128 - topk_method: str = "gready" - n_group: Optional[int] = None - topk_group: Optional[int] = None - num_experts_per_tok: Optional[int] = None - moe_layer_freq: int = 1 - first_k_dense_replace: int = 0 - max_position_embeddings: int = 2048 - rms_norm_eps: float = 1e-6 - rope_theta: float = 10000.0 - rope_scaling: Dict = None - attention_bias: bool = False - - -def yarn_find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 -): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - -def yarn_find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 -): - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) - - -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -def 
yarn_linear_ramp_mask(min_val, max_val, dim): - if min_val == max_val: - max_val += 0.001 # Prevent singularity - - linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) - return mx.clip(linear_func, 0, 1) - - -class DeepseekV2YarnRotaryEmbedding(nn.Module): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1.0, - original_max_position_embeddings=4096, - beta_fast=32, - beta_slow=1, - mscale=1, - mscale_all_dim=0, - ): - super().__init__() - self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( - scaling_factor, mscale_all_dim - ) - freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) - freq_inter = scaling_factor * base ** ( - mx.arange(0, dim, 2, dtype=mx.float32) / dim - ) - low, high = yarn_find_correction_range( - beta_fast, - beta_slow, - dim, - base, - original_max_position_embeddings, - ) - freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - self._freqs = (freq_inter * freq_extra) / ( - freq_inter * freq_mask + freq_extra * (1 - freq_mask) - ) - - def __call__(self, x, offset=0): - if self.mscale != 1.0: - x = self.mscale * x - return mx.fast.rope( - x, - x.shape[-1], - traditional=True, - base=None, - scale=1.0, - offset=offset, - freqs=self._freqs, - ) - - -class DeepseekV2Attention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - - self.scale = self.q_head_dim**-0.5 - - if self.q_lora_rank is None: - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.q_head_dim, bias=False - ) - else: - self.q_a_proj = nn.Linear( - self.hidden_size, self.q_lora_rank, bias=config.attention_bias - ) - self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank) - self.q_b_proj = nn.Linear( - self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False - ) - - self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=config.attention_bias, - ) - self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank) - self.kv_b_proj = nn.Linear( - self.kv_lora_rank, - self.num_heads - * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), - bias=False, - ) - - self.o_proj = nn.Linear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=config.attention_bias, - ) - - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale - - rope_kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rope = DeepseekV2YarnRotaryEmbedding( - dim=self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - **rope_kwargs, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: 
Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - if self.q_lora_rank is None: - q = self.q_proj(x) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) - - q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) - q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) - compressed_kv = self.kv_a_proj_with_mqa(x) - compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) - k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - - k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - - if cache is not None: - q_pe = self.rope(q_pe, cache.offset) - k_pe = self.rope(k_pe, cache.offset) - k_pe = mx.repeat(k_pe, self.num_heads, axis=1) - keys, values = cache.update_and_fetch( - mx.concatenate([k_nope, k_pe], axis=-1), values - ) - else: - q_pe = self.rope(q_pe) - k_pe = self.rope(k_pe) - k_pe = mx.repeat(k_pe, self.num_heads, axis=1) - keys = mx.concatenate([k_nope, k_pe], axis=-1) - - queries = mx.concatenate([q_nope, q_pe], axis=-1) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class DeepseekV2MLP(nn.Module): - def __init__( - self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None - ): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = ( - config.intermediate_size if intermediate_size is None else intermediate_size - ) - - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def __call__(self, x): - down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class MoEGate(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.topk_method = config.topk_method - self.n_group = config.n_group - self.topk_group = config.topk_group - self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) - - def __call__(self, x): - gates = x @ self.weight.T - - scores = mx.softmax(gates, axis=-1, precise=True) - - if self.topk_method == "group_limited_greedy": - bsz, seq_len = x.shape[:2] - scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = scores.max(axis=-1, keepdims=True) - k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] - scores = mx.put_along_axis( - scores, group_idx, mx.array(0.0, scores.dtype), axis=-2 - ) - scores = scores.reshape(bsz, seq_len, -1) - - k = self.top_k - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(scores, inds, axis=-1) - scores = scores * self.routed_scaling_factor - - return inds, scores - - -class DeepseekV2MoE(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - self.switch_mlp = SwitchGLU( - config.hidden_size, 
config.moe_intermediate_size, config.n_routed_experts - ) - - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP( - config=config, intermediate_size=intermediate_size - ) - - def __call__(self, x): - inds, scores = self.gate(x) - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(x) - - return y - - -class DeepseekV2DecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, layer_idx: int): - super().__init__() - self.self_attn = DeepseekV2Attention(config) - self.mlp = ( - DeepseekV2MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV2MLP(config) - ) - self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class DeepseekV2Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [ - DeepseekV2DecoderLayer(config, idx) - for idx in range(config.num_hidden_layers) - ] - self.start_idx = 0 - self.end_idx = len(self.layers) - self.num_layers = self.end_idx - - self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.pipeline_rank = 0 - self.pipeline_size = 1 - - def pipeline(self, group): - # Split layers in reverse so rank=0 gets the last layers and - # rank=pipeline_size-1 gets the first - self.pipeline_rank = group.rank() - self.pipeline_size = group.size() - layers_per_rank = len(self.layers) // self.pipeline_size - extra = len(self.layers) - layers_per_rank * self.pipeline_size - if self.pipeline_rank < extra: - layers_per_rank += 1 - - self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank - self.end_idx = self.start_idx + layers_per_rank - self.num_layers = layers_per_rank - self.layers = self.layers[: self.end_idx] - self.layers[: self.start_idx] = [None] * self.start_idx - self.num_layers = len(self.layers) - self.start_idx - - def __call__( - self, - x: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ) -> mx.array: - h = self.embed_tokens(x) - - pipeline_rank = self.pipeline_rank - pipeline_size = self.pipeline_size - # Hack to avoid time-outs during prompt-processing - dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * self.num_layers - - # Receive from the previous process in the pipeline - if pipeline_rank < pipeline_size - 1: - h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) - - for i in range(self.num_layers): - h = self.layers[self.start_idx + i](h, mask, cache[i]) - - # Send to the next process in the pipeline - if pipeline_rank != 0: - h = mx.distributed.send( - h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream - ) - - # Broadcast h while keeping it in the graph - h = 
mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]] - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.args = config - self.model_type = config.model_type - self.model = DeepseekV2Model(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ): - out = self.model(inputs, cache, mask) - return self.lm_head(out) - - def sanitize(self, weights): - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") - for e in range(self.args.n_routed_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) - return weights - - @property - def layers(self): - return self.model.layers[self.model.start_idx : self.model.end_idx] diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py deleted file mode 100644 index 5cd40a0d..00000000 --- a/llms/mlx_lm/models/deepseek_v3.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import math -from dataclasses import dataclass -from functools import partial -from typing import Any, Dict, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "deepseek_v3" - vocab_size: int = 102400 - hidden_size: int = 4096 - intermediate_size: int = 11008 - moe_intermediate_size: int = 1407 - num_hidden_layers: int = 30 - num_attention_heads: int = 32 - num_key_value_heads: int = 32 - n_shared_experts: Optional[int] = None - n_routed_experts: Optional[int] = None - routed_scaling_factor: float = 1.0 - kv_lora_rank: int = 512 - q_lora_rank: int = 1536 - qk_rope_head_dim: int = 64 - v_head_dim: int = 128 - qk_nope_head_dim: int = 128 - topk_method: str = "noaux_tc" - scoring_func: str = "sigmoid" - norm_topk_prob: bool = True - n_group: Optional[int] = None - topk_group: Optional[int] = None - num_experts_per_tok: Optional[int] = None - moe_layer_freq: int = 1 - first_k_dense_replace: int = 0 - max_position_embeddings: int = 2048 - rms_norm_eps: float = 1e-6 - rope_theta: float = 10000.0 - rope_scaling: Dict = None - attention_bias: bool = False - - -def yarn_find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 -): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - -def yarn_find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 -): - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) - - -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -def yarn_linear_ramp_mask(min_val, max_val, dim): - if min_val == max_val: - max_val += 0.001 # Prevent singularity - - linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) - return mx.clip(linear_func, 0, 1) - - 
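The four YaRN helpers above compute a per-dimension blend between the unscaled ("extrapolated") rotary frequencies and the position-interpolated ones; the rotary embedding class that follows stores the result as `_freqs` and hands it to `mx.fast.rope`. A short sketch of the blend, reusing the helpers defined above (the dimension, base, and scaling factor are illustrative):

```python
import mlx.core as mx

dim, base, scaling_factor = 64, 10000.0, 4.0
beta_fast, beta_slow, original_max = 32, 1, 4096

freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)  # unscaled
freq_inter = scaling_factor * freq_extra                             # interpolated

low, high = yarn_find_correction_range(
    beta_fast, beta_slow, dim, base, original_max
)
# Dimensions below `low` keep the unscaled value, dimensions above `high` use the
# interpolated one, and the band in between blends linearly via the ramp mask.
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
freqs = (freq_inter * freq_extra) / (
    freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
# `freqs` is the array the rotary embedding class stores as `_freqs`.
```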
-class DeepseekV3YarnRotaryEmbedding(nn.Module): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - scaling_factor=1.0, - original_max_position_embeddings=4096, - beta_fast=32, - beta_slow=1, - mscale=1, - mscale_all_dim=0, - ): - super().__init__() - self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( - scaling_factor, mscale_all_dim - ) - freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) - freq_inter = scaling_factor * base ** ( - mx.arange(0, dim, 2, dtype=mx.float32) / dim - ) - low, high = yarn_find_correction_range( - beta_fast, - beta_slow, - dim, - base, - original_max_position_embeddings, - ) - freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - self._freqs = (freq_inter * freq_extra) / ( - freq_inter * freq_mask + freq_extra * (1 - freq_mask) - ) - - def __call__(self, x, offset=0): - if self.mscale != 1.0: - x = self.mscale * x - return mx.fast.rope( - x, - x.shape[-1], - traditional=True, - base=None, - scale=1.0, - offset=offset, - freqs=self._freqs, - ) - - -# A clipped silu to prevent fp16 from overflowing -@partial(mx.compile, shapeless=True) -def clipped_silu(x): - return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100) - - -class DeepseekV3Attention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - - self.scale = self.q_head_dim**-0.5 - - if self.q_lora_rank is None: - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.q_head_dim, bias=False - ) - else: - self.q_a_proj = nn.Linear( - self.hidden_size, self.q_lora_rank, bias=config.attention_bias - ) - self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank) - self.q_b_proj = nn.Linear( - self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False - ) - - self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=config.attention_bias, - ) - self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank) - self.kv_b_proj = nn.Linear( - self.kv_lora_rank, - self.num_heads - * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), - bias=False, - ) - - self.o_proj = nn.Linear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=config.attention_bias, - ) - - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale - - rope_kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rope = DeepseekV3YarnRotaryEmbedding( - dim=self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - **rope_kwargs, - ) - else: - self.rope = nn.RoPE( - dims=self.qk_rope_head_dim, - base=self.rope_theta, - traditional=True, - ) 
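For readers tracing the projection sizes in the attention block above: queries and keys carry a non-rotary (`nope`) part plus a smaller rotary (`rope`) part, and the key/value path is compressed through a low-rank latent before being expanded per head. A small bookkeeping sketch with illustrative config values (plain Python; only the relationships mirror the projections above):

```python
# Illustrative numbers, not taken from any particular checkpoint.
num_heads = 16
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
kv_lora_rank, q_lora_rank = 512, 1536

q_head_dim = qk_nope_head_dim + qk_rope_head_dim        # 192 per head

q_b_out = num_heads * q_head_dim                        # q_b_proj output: 3072
kv_a_out = kv_lora_rank + qk_rope_head_dim              # kv_a_proj_with_mqa output: 576
kv_b_out = num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)  # kv_b_proj: 4096

# The rope slice produced by kv_a_proj_with_mqa is shared across heads and
# repeated at attention time; only that slice (and q_pe) is rotated.
assert q_head_dim - qk_rope_head_dim == qk_nope_head_dim
```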
- - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - if self.q_lora_rank is None: - q = self.q_proj(x) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) - - q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) - q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) - compressed_kv = self.kv_a_proj_with_mqa(x) - compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) - k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - - k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - - if cache is not None: - q_pe = self.rope(q_pe, cache.offset) - k_pe = self.rope(k_pe, cache.offset) - k_pe = mx.repeat(k_pe, self.num_heads, axis=1) - keys, values = cache.update_and_fetch( - mx.concatenate([k_nope, k_pe], axis=-1), values - ) - else: - q_pe = self.rope(q_pe) - k_pe = self.rope(k_pe) - k_pe = mx.repeat(k_pe, self.num_heads, axis=1) - keys = mx.concatenate([k_nope, k_pe], axis=-1) - - queries = mx.concatenate([q_nope, q_pe], axis=-1) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class DeepseekV3MLP(nn.Module): - def __init__( - self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None - ): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = ( - config.intermediate_size if intermediate_size is None else intermediate_size - ) - - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def __call__(self, x): - down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -@mx.compile -def group_expert_select( - gates, - e_score_correction_bias, - top_k, - n_group, - topk_group, - routed_scaling_factor, - norm_topk_prob, -): - - k = top_k - scores = mx.sigmoid(gates.astype(mx.float32)) - scores = scores + e_score_correction_bias - scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1)) - group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True) - k = n_group - topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] - scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2) - scores = mx.flatten(scores, -2, -1) - - k = top_k - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(scores, inds, axis=-1) - if top_k > 1 and norm_topk_prob: - denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 - scores = scores / denominator - scores = scores * routed_scaling_factor - - return inds, scores - - -class MoEGate(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group - self.topk_group = config.topk_group - self.weight = 
mx.zeros((self.n_routed_experts, config.hidden_size)) - self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) - assert config.topk_method == "noaux_tc", "Unsupported topk method." - - def __call__(self, x): - return group_expert_select( - x @ self.weight.T, - self.e_score_correction_bias, - self.top_k, - self.n_group, - self.topk_group, - self.routed_scaling_factor, - self.norm_topk_prob, - ) - - -class DeepseekV3MoE(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - self.switch_mlp = SwitchGLU( - config.hidden_size, - config.moe_intermediate_size, - config.n_routed_experts, - activation=clipped_silu, - ) - - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP( - config=config, intermediate_size=intermediate_size - ) - - def __call__(self, x): - inds, scores = self.gate(x) - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(x) - - return y - - -class DeepseekV3DecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, layer_idx: int): - super().__init__() - self.self_attn = DeepseekV3Attention(config) - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) - self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - return h + r - - -class DeepseekV3Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [ - DeepseekV3DecoderLayer(config, idx) - for idx in range(config.num_hidden_layers) - ] - self.start_idx = 0 - self.end_idx = len(self.layers) - self.num_layers = self.end_idx - - self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.pipeline_rank = 0 - self.pipeline_size = 1 - - def pipeline(self, group): - # Split layers in reverse so rank=0 gets the last layers and - # rank=pipeline_size-1 gets the first - self.pipeline_rank = group.rank() - self.pipeline_size = group.size() - layers_per_rank = len(self.layers) // self.pipeline_size - extra = len(self.layers) - layers_per_rank * self.pipeline_size - if self.pipeline_rank < extra: - layers_per_rank += 1 - self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank - self.end_idx = self.start_idx + layers_per_rank - self.layers = self.layers[: self.end_idx] - self.layers[: self.start_idx] = [None] * self.start_idx - self.num_layers = len(self.layers) - self.start_idx - - def __call__( - self, - x: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ) -> mx.array: - h = self.embed_tokens(x) - - pipeline_rank = self.pipeline_rank - pipeline_size = self.pipeline_size - # Hack to avoid time-outs during prompt-processing - dist_stream = mx.cpu if 
h.shape[1] > 1 else mx.gpu - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * self.num_layers - - # Receive from the previous process in the pipeline - - if pipeline_rank < pipeline_size - 1: - h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) - - for i in range(self.num_layers): - h = self.layers[self.start_idx + i](h, mask, cache[i]) - - # Send to the next process in the pipeline - if pipeline_rank != 0: - h = mx.distributed.send( - h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream - ) - - # Broadcast h while keeping it in the graph - h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]] - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.args = config - self.model_type = config.model_type - self.model = DeepseekV3Model(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ): - out = self.model(inputs, cache, mask) - return self.lm_head(out) - - def sanitize(self, weights): - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") - for e in range(self.args.n_routed_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) - - # Remove multi-token prediction layer and any unused precomputed rotary freqs - return { - k: v - for k, v in weights.items() - if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k - } - - @property - def layers(self): - return self.model.layers[self.model.start_idx : self.model.end_idx] diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py deleted file mode 100644 index ee3ed1e8..00000000 --- a/llms/mlx_lm/models/exaone.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright © 2024 Apple Inc. 
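Stepping back to the DeepSeek models above: their `sanitize` methods rewrite per-expert `gate_proj`/`up_proj`/`down_proj` checkpoint entries into a single stacked tensor per projection under `switch_mlp`, the layout consumed by the fused `SwitchGLU` layer. A hedged sketch of that transformation on toy weights (key names mirror the checkpoints; sizes are illustrative):

```python
import mlx.core as mx

n_experts, hidden, moe_hidden = 4, 8, 16
prefix = "model.layers.0"
weights = {
    f"{prefix}.mlp.experts.{e}.{m}.weight": mx.zeros(shape)
    for e in range(n_experts)
    for m, shape in [
        ("gate_proj", (moe_hidden, hidden)),
        ("up_proj", (moe_hidden, hidden)),
        ("down_proj", (hidden, moe_hidden)),
    ]
}

# Stack the per-expert matrices into one (n_experts, out_dim, in_dim) array
# per projection, stored under the fused switch_mlp key.
for m in ["gate_proj", "up_proj", "down_proj"]:
    stacked = mx.stack(
        [weights.pop(f"{prefix}.mlp.experts.{e}.{m}.weight") for e in range(n_experts)]
    )
    weights[f"{prefix}.mlp.switch_mlp.{m}.weight"] = stacked

print(weights[f"{prefix}.mlp.switch_mlp.gate_proj.weight"].shape)  # (4, 16, 8)
```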
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .rope_utils import initialize_rope - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_layers: int - intermediate_size: int - num_attention_heads: int - vocab_size: int - rope_theta: float - layer_norm_epsilon: float - num_key_value_heads: int - head_dim: Optional[int] = None - max_position_embeddings: Optional[int] = None - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = True - attention_bias: bool = False - mlp_bias: bool = False - - -class AttentionModule(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.head_dim = head_dim = args.head_dim or (dim // n_heads) - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) - self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) - - self.rope = initialize_rope( - self.head_dim, - args.rope_theta, - args.rope_traditional, - args.rope_scaling, - args.max_position_embeddings, - ) - - def __call__( - self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None - ) -> mx.array: - B, L, D = x.shape - q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - q = self.rope(q, offset=cache.offset) - k = self.rope(k, offset=cache.offset) - k, v = cache.update_and_fetch(k, v) - else: - q = self.rope(q) - k = self.rope(k) - - out = scaled_dot_product_attention( - q, k, v, cache=cache, scale=self.scale, mask=mask - ) - out = out.transpose(0, 2, 1, 3).reshape(B, L, D) - return self.out_proj(out) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.attention = AttentionModule(args) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - dim = args.hidden_size - hidden_dim = args.intermediate_size - self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) - self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) - self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias) - - def __call__(self, x: mx.array) -> mx.array: - return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.attn = Attention(args) - self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.mlp = MLP(args) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - h = x + self.attn.attention(self.ln_1(x), mask, cache) - out = h + self.mlp(self.ln_2(h)) - return out - - -class ExaoneModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.wte = nn.Embedding(args.vocab_size, 
args.hidden_size) - self.h = [TransformerBlock(args) for _ in range(args.num_layers)] - self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.wte(inputs) - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.h) - - for layer, c in zip(self.h, cache): - h = layer(h, mask, cache=c) - - return self.ln_f(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.transformer = ExaoneModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.transformer(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.transformer.wte.as_linear(out) - else: - out = self.lm_head(out) - return out - - @property - def layers(self): - return self.transformer.h diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py deleted file mode 100644 index 0860ddeb..00000000 --- a/llms/mlx_lm/models/gemma.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - head_dim: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int - rope_theta: float = 10000 - rope_traditional: bool = False - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def __call__(self, x): - return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.head_dim = head_dim = args.head_dim - - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, 
keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class GemmaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - h = h * (self.args.hidden_size**0.5) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = GemmaModel(args) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - out = self.model.embed_tokens.as_linear(out) - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py deleted file mode 100644 index 321a58ff..00000000 --- a/llms/mlx_lm/models/gemma2.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
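Gemma checkpoints store the RMSNorm weight as an offset from one, which is why the `RMSNorm` above passes `1.0 + self.weight` to `mx.fast.rms_norm`. A small sketch with illustrative shapes:

```python
import mlx.core as mx

x = mx.random.normal((1, 4, 8))
w = mx.zeros((8,))                      # a zero weight means an identity scale
y = mx.fast.rms_norm(x, 1.0 + w, 1e-6)  # plain RMS normalization in this case
print(y.shape)                          # (1, 4, 8)
```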
- -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - head_dim: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int - rope_theta: float = 10000 - rope_traditional: bool = False - attn_logit_softcapping: float = 50.0 - final_logit_softcapping: float = 30.0 - query_pre_attn_scalar: float = 144.0 - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def __call__(self, x): - return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.repeats = n_heads // n_kv_heads - self.head_dim = head_dim = args.head_dim - - self.scale = 1.0 / (args.query_pre_attn_scalar**0.5) - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - self.attn_logit_softcapping = args.attn_logit_softcapping - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries * self.scale - - if self.repeats > 1: - queries = queries.reshape( - B, self.n_kv_heads, self.repeats, L, self.head_dim - ) - keys = mx.expand_dims(keys, 2) - values = mx.expand_dims(values, 2) - - scores = queries @ keys.swapaxes(-1, -2) - scores = mx.tanh(scores / self.attn_logit_softcapping) - scores *= self.attn_logit_softcapping - - if mask is not None: - scores = scores + mask - scores = mx.softmax(scores, precise=True, axis=-1) - output = scores @ values - if self.repeats > 1: - output = output.reshape(B, self.n_heads, L, self.head_dim) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - 
self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.pre_feedforward_layernorm = RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.post_feedforward_layernorm = RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + self.post_attention_layernorm(r) - r = self.mlp(self.pre_feedforward_layernorm(h)) - out = h + self.post_feedforward_layernorm(r) - return out - - -class GemmaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - h = h * (self.args.hidden_size**0.5) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.final_logit_softcapping = args.final_logit_softcapping - self.model = GemmaModel(args) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - out = self.model.embed_tokens.as_linear(out) - out = mx.tanh(out / self.final_logit_softcapping) - out = out * self.final_logit_softcapping - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py deleted file mode 100644 index be71f461..00000000 --- a/llms/mlx_lm/models/gemma3_text.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright © 2025 Apple Inc. 
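The attention and final-logit soft-capping in the Gemma 2 code above squashes values smoothly into `(-cap, cap)` rather than clipping them. A minimal sketch:

```python
import mlx.core as mx

def softcap(x: mx.array, cap: float) -> mx.array:
    # tanh(x / cap) * cap: roughly linear for |x| << cap, saturating near ±cap.
    return mx.tanh(x / cap) * cap

x = mx.array([-500.0, -10.0, 0.0, 10.0, 500.0])
print(softcap(x, 50.0))
```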
- -from dataclasses import dataclass -from typing import Any, Optional - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask -from .cache import KVCache, RotatingKVCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int = 1152 - num_hidden_layers: int = 26 - intermediate_size: int = 6912 - num_attention_heads: int = 4 - head_dim: int = 256 - rms_norm_eps: float = 1.0e-6 - vocab_size: int = 262144 - num_key_value_heads: int = 1 - rope_global_base_freq: float = 1_000_000.0 - rope_local_base_freq: float = 10_000.0 - rope_traditional: bool = False - query_pre_attn_scalar: float = 256 - sliding_window: int = 512 - sliding_window_pattern: int = 6 - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.repeats = n_heads // n_kv_heads - self.head_dim = head_dim = args.head_dim - self.layer_idx = layer_idx - - self.scale = args.query_pre_attn_scalar**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps) - self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps) - self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0 - - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=( - args.rope_local_base_freq - if self.is_sliding - else args.rope_global_base_freq - ), - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, _ = x.shape - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - queries = self.q_norm(queries) - keys = self.k_norm(keys) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - # Sliding window - if mask is not None and mask.shape[-1] != keys.shape[-2]: - mask = mask[..., -keys.shape[-2] :] - - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def __call__(self, x): - return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - 
super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args, layer_idx) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.pre_feedforward_layernorm = RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.post_feedforward_layernorm = RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + self.post_attention_layernorm(r) - r = self.mlp(self.pre_feedforward_layernorm(h)) - out = h + self.post_feedforward_layernorm(r) - return out - - -class Gemma3Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args, layer_idx=layer_idx) - for layer_idx in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - - h = self.embed_tokens(inputs) - h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - if mask is None: - j = self.args.sliding_window_pattern - full_mask = create_attention_mask(h, cache[j - 1 : j]) - sliding_window_mask = create_attention_mask(h, cache) - - for i, (layer, c) in enumerate(zip(self.layers, cache)): - is_sliding = ( - i % self.args.sliding_window_pattern - == self.args.sliding_window_pattern - 1 - ) - - if mask is None and is_sliding: - mask = sliding_window_mask - elif mask is None: - mask = full_mask - - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = Gemma3Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - mask: Optional[mx.array] = None, - ): - out = self.model(inputs, mask, cache) - out = self.lm_head(out) - return out - - def sanitize(self, weights): - if "lm_head.weight" not in weights: - weights["lm_head.weight"] = weights["model.embed_tokens.weight"] - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } - - @property - def layers(self): - return self.model.layers - - def make_cache(self): - caches = [] - for i in range(self.args.num_hidden_layers): - if ( - i % self.args.sliding_window_pattern - == self.args.sliding_window_pattern - 1 - ): - caches.append(KVCache()) - else: - caches.append( - RotatingKVCache(max_size=self.args.sliding_window, keep=0) - ) - return caches diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py deleted file mode 100644 index 5b277734..00000000 --- a/llms/mlx_lm/models/gpt2.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
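In the Gemma 3 text model above, only every `sliding_window_pattern`-th layer attends globally and gets a regular `KVCache`; the other layers use a `RotatingKVCache` bounded by `sliding_window`. With the default configuration, the global layers are:

```python
# Which layer indices attend globally for the default 26-layer, pattern-6 config.
pattern, n_layers = 6, 26
global_layers = [i for i in range(n_layers) if i % pattern == pattern - 1]
print(global_layers)  # [5, 11, 17, 23]
```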
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - n_ctx: int - n_embd: int - n_head: int - n_layer: int - n_positions: int - layer_norm_epsilon: float - vocab_size: int - num_key_value_heads: int = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.n_head - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head" - - self.n_embd = args.n_embd - self.n_head = args.n_head - self.head_dim = self.n_embd // self.n_head - - self.scale = self.head_dim**-0.5 - - self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True) - self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv = self.c_attn(x) - queries, keys, values = mx.split(qkv, 3, axis=-1) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3) - - if cache is not None: - keys, values = cache.update_and_fetch(keys, values) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.c_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.n_embd = args.n_embd - self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd) - self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd) - - def __call__(self, x) -> mx.array: - return self.c_proj(nn.gelu_approx(self.c_fc(x))) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.n_head = args.n_head - self.n_embd = args.n_embd - self.layer_norm_epsilon = args.layer_norm_epsilon - self.attn = Attention(args) - self.mlp = MLP(args) - self.ln_1 = nn.LayerNorm( - self.n_embd, - eps=self.layer_norm_epsilon, - ) - self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.attn(self.ln_1(x), mask, cache) - h = x + r - r = self.mlp(self.ln_2(h)) - out = h + r - return out - - -class GPT2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_embd = args.n_embd - self.n_positions = args.n_positions - self.vocab_size = args.vocab_size - self.n_layer = args.n_layer - self.layer_norm_epsilon = args.layer_norm_epsilon - assert self.vocab_size > 0 - self.wte = nn.Embedding(self.vocab_size, self.n_embd) - self.wpe = nn.Embedding(self.n_positions, self.n_embd) - self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)] - self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - _, L = inputs.shape - - hidden_states = self.wte(inputs) - - mask = None - if hidden_states.shape[1] > 1: - - position_ids = 
mx.array(np.arange(L)) - hidden_states += self.wpe(position_ids) - - if mask is None: - mask = create_attention_mask(hidden_states, cache) - - if cache is None: - cache = [None] * len(self.h) - - for layer, c in zip(self.h, cache): - hidden_states = layer(hidden_states, mask, cache=c) - - return self.ln_f(hidden_states) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = GPT2Model(args) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - out = self.model.wte.as_linear(out) - return out - - def sanitize(self, weights): - new_weights = {} - for i in range(self.args.n_layer): - if f"h.{i}.attn.bias" in weights: - del weights[f"h.{i}.attn.bias"] - if f"h.{i}.attn.c_attn.weight" in weights: - weights[f"h.{i}.attn.c_attn.weight"] = weights[ - f"h.{i}.attn.c_attn.weight" - ].transpose(1, 0) - if f"h.{i}.attn.c_proj.weight" in weights: - weights[f"h.{i}.attn.c_proj.weight"] = weights[ - f"h.{i}.attn.c_proj.weight" - ].transpose(1, 0) - if f"h.{i}.mlp.c_fc.weight" in weights: - weights[f"h.{i}.mlp.c_fc.weight"] = weights[ - f"h.{i}.mlp.c_fc.weight" - ].transpose(1, 0) - if f"h.{i}.mlp.c_proj.weight" in weights: - weights[f"h.{i}.mlp.c_proj.weight"] = weights[ - f"h.{i}.mlp.c_proj.weight" - ].transpose(1, 0) - for weight in weights: - if not weight.startswith("model."): - new_weights[f"model.{weight}"] = weights[weight] - else: - new_weights[weight] = weights[weight] - return new_weights - - @property - def layers(self): - return self.model.h diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py deleted file mode 100644 index 1d9794b6..00000000 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
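The GPT-2 `sanitize` above transposes the checkpoint's Conv1D-style weights, which Hugging Face stores as `(in_features, out_features)`, into the `(out_features, in_features)` layout that `nn.Linear` expects. For example, with GPT-2 small shapes:

```python
import mlx.core as mx

w_hf = mx.zeros((768, 3 * 768))  # h.{i}.attn.c_attn.weight as stored in the checkpoint
w_mlx = w_hf.transpose(1, 0)     # (2304, 768), usable as an nn.Linear weight
print(w_mlx.shape)
```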
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - n_embd: int - n_layer: int - n_inner: int - n_head: int - n_positions: int - layer_norm_epsilon: float - vocab_size: int - num_key_value_heads: int = None - multi_query: bool = True - attention_bias: bool = True - mlp_bias: bool = True - tie_word_embeddings: bool = True - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = 1 if self.multi_query else self.n_head - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.dim = dim = args.n_embd - self.n_heads = n_heads = args.n_head - self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head - - self.head_dim = head_dim = dim // n_heads - - self.kv_dim = n_kv_heads * head_dim - - self.scale = head_dim**-0.5 - - if hasattr(args, "attention_bias"): - attention_bias = args.attention_bias - else: - attention_bias = False - - self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias) - self.c_proj = nn.Linear(dim, dim, bias=attention_bias) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv = self.c_attn(x) - queries, keys, values = mx.split( - qkv, [self.dim, self.dim + self.kv_dim], axis=-1 - ) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - keys, values = cache.update_and_fetch(keys, values) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.c_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.n_embd - hidden_dim = args.n_inner - if hasattr(args, "mlp_bias"): - mlp_bias = args.mlp_bias - else: - mlp_bias = False - - self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias) - self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) - - def __call__(self, x) -> mx.array: - return self.c_proj(nn.gelu(self.c_fc(x))) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_head = args.n_head - self.n_embd = args.n_embd - self.attn = Attention(args) - self.mlp = MLP(args) - self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) - self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.attn(self.ln_1(x), mask, cache) - h = x + r - r = self.mlp(self.ln_2(h)) - out = h + r - return out - - -class GPTBigCodeModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - assert self.vocab_size > 0 - self.wte = nn.Embedding(args.vocab_size, args.n_embd) - self.wpe = nn.Embedding(args.n_positions, args.n_embd) - self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)] - 
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - B, L = inputs.shape - - hidden_states = self.wte(inputs) - - mask = None - if mask is not None and hidden_states.shape[1] > 1: - mask = create_attention_mask(hidden_states, cache) - - if cache is None: - cache = [None] * len(self.h) - position_ids = mx.array(np.arange(L)) - else: - position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L)) - - hidden_states += self.wpe(position_ids) - - for layer, c in zip(self.h, cache): - hidden_states = layer(hidden_states, mask, cache=c) - - return self.ln_f(hidden_states) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.transformer = GPTBigCodeModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.transformer(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.transformer.wte.as_linear(out) - else: - out = self.lm_head(out) - return out - - @property - def layers(self): - return self.transformer.h diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py deleted file mode 100644 index 5e124a67..00000000 --- a/llms/mlx_lm/models/gpt_neox.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - -# Based on the transformers implementation at: -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - max_position_embeddings: int - hidden_size: int - num_attention_heads: int - num_hidden_layers: int - layer_norm_eps: float - vocab_size: int - rotary_emb_base: int - rotary_pct: float - num_key_value_heads: int = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - assert ( - args.hidden_size % args.num_attention_heads == 0 - ), "hidden_size must be divisible by num_attention_heads" - - self.hidden_size = args.hidden_size - self.num_attention_heads = args.num_attention_heads - self.head_dim = self.hidden_size // self.num_attention_heads - - self.rope = nn.RoPE( - dims=int(self.head_dim * args.rotary_pct), - traditional=False, - base=args.rotary_emb_base, - ) - - self.scale = self.head_dim**-0.5 - - self.query_key_value = nn.Linear( - self.hidden_size, 3 * self.hidden_size, bias=True - ) - self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv = self.query_key_value(x) - - new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim) - qkv = qkv.reshape(*new_qkv_shape) - - queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)] - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) 
- keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.dense(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.hidden_size = args.hidden_size - self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size) - self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size) - - def __call__(self, x) -> mx.array: - # gelu_approx corresponds to FastGELUActivation in transformers. - return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x))) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.hidden_size = args.hidden_size - self.layer_norm_eps = args.layer_norm_eps - self.attention = Attention(args) - self.mlp = MLP(args) - self.input_layernorm = nn.LayerNorm( - self.hidden_size, - eps=self.layer_norm_eps, - ) - self.post_attention_layernorm = nn.LayerNorm( - self.hidden_size, eps=self.layer_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - residual = x - # NeoX runs attention and feedforward network in parallel. - attn = self.attention(self.input_layernorm(x), mask, cache) - ffn = self.mlp(self.post_attention_layernorm(x)) - out = attn + ffn + residual - return out - - -class GPTNeoXModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - self.layer_norm_eps = args.layer_norm_eps - assert self.vocab_size > 0 - self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size) - self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)] - self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - _, L = inputs.shape - - hidden_states = self.embed_in(inputs) - - if mask is None: - mask = create_attention_mask(hidden_states, cache) - - if cache is None: - cache = [None] * len(self.h) - - for layer, c in zip(self.h, cache): - hidden_states = layer(hidden_states, mask, cache=c) - - out = self.final_layer_norm(hidden_states) - out = self.embed_out(out) - - return out - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = GPTNeoXModel(args) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - return out - - def sanitize(self, weights): - new_weights = {} - - for w_key, w_value in weights.items(): - # Created through register_buffer in Pytorch, not needed here. 
- ignore_suffixes = [ - ".attention.bias", - ".attention.masked_bias", - ".attention.rotary_emb.inv_freq", - ] - - skip_weight = False - for ignored_suffix in ignore_suffixes: - if w_key.endswith(ignored_suffix): - skip_weight = True - break - - if skip_weight: - continue - - if not w_key.startswith("model."): - w_key = f"model.{w_key}" - - w_key = w_key.replace(".gpt_neox.layers.", ".h.") - w_key = w_key.replace(".gpt_neox.", ".") - - new_weights[w_key] = w_value - - return new_weights - - @property - def layers(self): - return self.model.h diff --git a/llms/mlx_lm/models/granite.py b/llms/mlx_lm/models/granite.py deleted file mode 100644 index 43597d99..00000000 --- a/llms/mlx_lm/models/granite.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .rope_utils import initialize_rope - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - logits_scaling: float - attention_multiplier: float - embedding_multiplier: float - residual_multiplier: float - max_position_embeddings: int - num_key_value_heads: int - attention_bias: bool - mlp_bias: bool - rope_theta: float - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = True - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.head_dim = head_dim = args.hidden_size // n_heads - - self.scale = args.attention_multiplier - attention_bias = args.attention_bias - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - - self.rope = initialize_rope( - self.head_dim, - args.rope_theta, - False, - args.rope_scaling, - args.max_position_embeddings, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - hidden_dim = args.intermediate_size - if hasattr(args, "mlp_bias"): - mlp_bias = args.mlp_bias - else: - mlp_bias = 
False - - self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) - self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.residual_multiplier = args.residual_multiplier - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r * self.residual_multiplier - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r * self.residual_multiplier - return out - - -class GraniteModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.embedding_multiplier = args.embedding_multiplier - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) * self.embedding_multiplier - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = GraniteModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - self.logits_scaling = args.logits_scaling - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out / self.logits_scaling - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py deleted file mode 100644 index ff551bca..00000000 --- a/llms/mlx_lm/models/helium.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright © 2025 Apple Inc. 
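Granite (above) threads several config-driven scales through the forward pass: embeddings are multiplied by `embedding_multiplier`, each residual branch by `residual_multiplier`, attention uses `attention_multiplier` as its softmax scale, and the logits are divided by `logits_scaling`. A toy sketch with made-up values:

```python
import mlx.core as mx

embedding_multiplier, residual_multiplier, logits_scaling = 12.0, 0.22, 16.0

h = mx.random.normal((1, 3, 8)) * embedding_multiplier  # scaled token embeddings
r = mx.random.normal((1, 3, 8))                         # output of an attention or MLP block
h = h + r * residual_multiplier                         # scaled residual connection
logits = mx.random.normal((1, 3, 32)) / logits_scaling  # down-scaled output logits
print(h.shape, logits.shape)
```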
- -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - num_key_value_heads: int - rms_norm_eps: float - vocab_size: int - attention_bias: bool - head_dim: int - max_position_embeddings: int - mlp_bias: bool - model_type: str - rope_theta: float - tie_word_embeddings: bool - - -class HeliumAttention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class HeliumMLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.intermediate_size = args.intermediate_size - - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=args.mlp_bias - ) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=args.mlp_bias - ) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias=args.mlp_bias - ) - - def __call__(self, x: mx.array) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class HeliumDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - - self.self_attn = HeliumAttention(args) - self.mlp = HeliumMLP(args) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class HeliumModel(nn.Module): - def 
__init__(self, args: ModelArgs): - super().__init__() - self.num_hidden_layers = args.num_hidden_layers - self.vocab_size = args.vocab_size - - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - - self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)] - - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - - self.model = HeliumModel(args) - - self.vocab_size = args.vocab_size - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py deleted file mode 100644 index 122cebda..00000000 --- a/llms/mlx_lm/models/hunyuan.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - num_key_value_heads: int - attention_bias: bool - moe_topk: int - num_experts: int - num_shared_expert: int - use_mixed_mlp_moe: bool - use_qk_norm: bool - rms_norm_eps: float - rope_theta: float - use_cla: bool - cla_share_factor: 2 - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = False - - def __post_init__(self): - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - -class DynamicNTKAlphaRoPE(nn.Module): - def __init__( - self, - dims: int, - base: float = 10000, - scaling_alpha: float = 1.0, - ): - super().__init__() - self.dims = dims - base = base * scaling_alpha ** (dims / (dims - 2)) - self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims) - - def __call__(self, x, offset: int = 0): - return mx.fast.rope( - x, - self.dims, - traditional=False, - base=None, - scale=1.0, - offset=offset, - freqs=self._freqs, - ) - - -class Attention(nn.Module): - def __init__(self, kv_proj: bool, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // n_heads - self.scale = 
head_dim**-0.5 - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) - if kv_proj: - self.k_proj = nn.Linear( - dim, n_kv_heads * head_dim, bias=args.attention_bias - ) - self.v_proj = nn.Linear( - dim, n_kv_heads * head_dim, bias=args.attention_bias - ) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) - self.use_qk_norm = args.use_qk_norm - if self.use_qk_norm: - self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) - self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) - - self.rope = DynamicNTKAlphaRoPE( - head_dim, - base=args.rope_theta, - scaling_alpha=args.rope_scaling["alpha"], - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - kv_states=None, - ) -> mx.array: - B, L, D = x.shape - - queries = self.q_proj(x) - if kv_states is None: - keys, values = self.k_proj(x), self.v_proj(x) - kv_states = keys, values - else: - keys, values = kv_states - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - offset = cache.offset if cache else 0 - queries = self.rope(queries, offset=offset) - keys = self.rope(keys, offset=offset) - if self.use_qk_norm: - queries = self.query_layernorm(queries) - keys = self.key_layernorm(keys) - - if cache is not None: - keys, values = cache.update_and_fetch(keys, values) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), kv_states - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class Gate(nn.Module): - def __init__(self, dim, num_experts): - super().__init__() - self.wg = nn.Linear(dim, num_experts, bias=False) - - def __call__(self, x) -> mx.array: - return self.wg(x) - - -class MoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - dim = args.hidden_size - intermediate_size = args.intermediate_size - self.use_shared_mlp = args.use_mixed_mlp_moe - - if args.use_mixed_mlp_moe: - self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert) - - self.num_experts = num_experts = args.num_experts - self.top_k = args.moe_topk - - self.gate = Gate(dim, num_experts) - self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) - - def __call__( - self, - x: mx.array, - ): - gates = self.gate(x) - gates = mx.softmax(gates, axis=-1, precise=True) - - k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) - scores = mx.take_along_axis(gates, inds, axis=-1) - - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - - if self.use_shared_mlp: - shared_expert_output = self.shared_mlp(x) - y = y + shared_expert_output - - return y - - -class DecoderLayer(nn.Module): - def __init__(self, args: ModelArgs, kv_proj: bool): - super().__init__() - self.hidden_size = args.hidden_size - self.self_attn = Attention(kv_proj, args) - if args.num_experts == 1: 
- self.mlp = MLP(args.hidden_size, args.intermediate_size) - else: - self.mlp = MoeBlock(args) - - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None, - ): - r, shared_kv_states = self.self_attn( - self.input_layernorm(x), mask, cache, shared_kv_states - ) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, shared_kv_states - - -class HunYuanModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - DecoderLayer( - args=args, - kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0, - ) - for i in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for i, (layer, c) in enumerate(zip(self.layers, cache)): - if (not self.args.use_cla) or i % self.args.cla_share_factor == 0: - shared_kv_states = None - h, shared_kv_states = layer(h, mask, c, shared_kv_states) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = HunYuanModel(args) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - return self.model.embed_tokens.as_linear(out) - - def sanitize(self, weights): - - if "model.layers.0.mlp.gate_and_up_proj.weight" in weights: - new_weights = {} - D = self.args.hidden_size - n_kv_heads = self.args.num_key_value_heads - n_kv_groups = self.args.num_attention_heads // n_kv_heads - head_dim = D // self.args.num_attention_heads - for k, v in weights.items(): - if "qkv_proj" in k: - v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1) - splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1) - for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits): - k_new = k.replace("qkv_proj", k_up) - new_weights[k_new] = mx.flatten(v_new, 0, 2) - elif "gate_and_up_proj" in k: - splits = v.split(2, axis=0) - for k_up, v_new in zip(["up_proj", "gate_proj"], splits): - k_new = k.replace("gate_and_up_proj", k_up) - new_weights[k_new] = v_new - else: - new_weights[k] = v - weights = new_weights - - if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: - return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py deleted 
file mode 100644 index 28a095e1..00000000 --- a/llms/mlx_lm/models/internlm2.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - bias: bool = True - max_position_embeddings: int = 32768 - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = False - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] not in ["linear", "dynamic"]: - raise ValueError( - "rope_scaling 'type' currently only supports 'linear' or 'dynamic'" - ) - - -class DynamicNTKScalingRoPE(nn.Module): - """Implements the rotary positional encoding with Dynamic NTK scaling.""" - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scale: float = 1.0, - ): - super().__init__() - self.max_position_embeddings = max_position_embeddings - self.original_base = base - self.dims = dims - self.traditional = traditional - self.scale = scale - - def extra_repr(self): - return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scale}" - - def __call__(self, x, offset: int = 0): - seq_len = x.shape[1] + offset - if seq_len > self.max_position_embeddings: - base = self.original_base * ( - (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) - ) ** (self.dims / (self.dims - 2)) - else: - base = self.original_base - - return mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=base, - scale=self.scale, - offset=offset, - ) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.n_kv_groups = n_heads // args.num_key_value_heads - - self.head_dim = head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.wqkv = nn.Linear( - dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias - ) - self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias) - - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 2.0 - ) - - self.rope = DynamicNTKScalingRoPE( - head_dim, - max_position_embeddings=args.max_position_embeddings, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv_states = self.wqkv(x) - qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim) - - queries = qkv_states[..., : self.n_kv_groups, :] - queries =
queries.reshape(B, L, -1, self.head_dim) - keys = qkv_states[..., -2, :] - values = qkv_states[..., -1, :] - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.wo(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.w2(nn.silu(self.w1(x)) * self.w3(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.attention = Attention(args) - self.feed_forward = MLP(args.hidden_size, args.intermediate_size) - self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.attention(self.attention_norm(x), mask, cache) - h = x + r - r = self.feed_forward(self.ffn_norm(h)) - out = h + r - return out - - -class InternLM2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - assert args.vocab_size > 0 - self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.tok_embeddings(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = InternLM2Model(args) - if not args.tie_word_embeddings: - self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.tok_embeddings.as_linear(out) - else: - out = self.output(out) - return out - - def sanitize(self, weights): - # Remove unused precomputed rotary freqs - return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k} - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/internlm3.py b/llms/mlx_lm/models/internlm3.py deleted file mode 100644 index 3be6f536..00000000 --- a/llms/mlx_lm/models/internlm3.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - bias: bool = False - qkv_bias: bool = False - max_position_embeddings: int = 32768 - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = False - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "rope_type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]: - raise ValueError( - "rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic'" - ) - - -class DynamicNTKScalingRoPE(nn.Module): - """Implements the rotary positional encoding with Dynamic NTK scaling.""" - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scale: float = 1.0, - ): - super().__init__() - self.max_position_embeddings = max_position_embeddings - self.original_base = base - self.dims = dims - self.traditional = traditional - self.scale = scale - - def extra_repr(self): - return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scale}" - - def __call__(self, x, offset: int = 0): - seq_len = x.shape[1] + offset - if seq_len > self.max_position_embeddings: - base = self.original_base * ( - (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) - ) ** (self.dims / (self.dims - 2)) - else: - base = self.original_base - - return mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=base, - scale=self.scale, - offset=offset, - ) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - qkv_bias = args.qkv_bias - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.n_kv_groups = n_heads // args.num_key_value_heads - - self.head_dim = head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias) - - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None - and args.rope_scaling["rope_type"] == "linear" - else 2.0 - ) - - self.rope = DynamicNTKScalingRoPE( - head_dim, - max_position_embeddings=args.max_position_embeddings, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the
attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim, bias): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias) - self.down_proj = nn.Linear(hidden_dim, dim, bias=bias) - self.up_proj = nn.Linear(dim, hidden_dim, bias=bias) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class InternLM2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - assert args.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = InternLM2Model(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - def sanitize(self, weights): - # Remove unused precomputed rotary freqs - return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k} - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py deleted file mode 100644 index 117adf0f..00000000 --- a/llms/mlx_lm/models/llama.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .rope_utils import initialize_rope - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - head_dim: Optional[int] = None - max_position_embeddings: Optional[int] = None - num_key_value_heads: Optional[int] = None - attention_bias: bool = False - mlp_bias: bool = False - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = True - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads - - self.scale = head_dim**-0.5 - if hasattr(args, "attention_bias"): - attention_bias = args.attention_bias - else: - attention_bias = False - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - - self.rope = initialize_rope( - self.head_dim, - args.rope_theta, - args.rope_traditional, - args.rope_scaling, - args.max_position_embeddings, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - hidden_dim = args.intermediate_size - if hasattr(args, "mlp_bias"): - mlp_bias = args.mlp_bias - else: - mlp_bias = False - - self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) - self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args) 
- self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = LlamaModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - def sanitize(self, weights): - # Remove unused precomputed rotary freqs - weights = { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } - if self.args.tie_word_embeddings: - weights.pop("lm_head.weight", None) - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py deleted file mode 100644 index 93cc616e..00000000 --- a/llms/mlx_lm/models/mamba.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright © 2024-2025 Apple Inc. 
- -import math -from dataclasses import dataclass - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import MambaCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - time_step_rank: int - tie_word_embeddings: bool = True - use_bcdt_rms: bool = False - mixer_rms_eps: float = 1e-6 - - def __post_init__(self): - if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): - self.hidden_size = self.d_model - if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"): - self.intermediate_size = self.d_inner - if not hasattr(self, "state_size") and hasattr(self, "d_state"): - self.state_size = self.d_state - if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"): - self.num_hidden_layers = self.n_layer - if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"): - self.num_hidden_layers = self.n_layers - if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"): - self.conv_kernel = self.d_conv - if not hasattr(self, "use_bias") and hasattr(self, "bias"): - self.use_bias = self.bias - if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"): - self.use_conv_bias = self.conv_bias - - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - if self.model_type == "falcon_mamba": - self.use_bcdt_rms = True - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias=True, padding=0): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.padding = padding - self.weight = mx.random.normal((self.channels, kernel_size, 1)) - self.bias = mx.zeros((channels,)) if bias else None - - def __call__(self, x, cache=None): - B, L, C = x.shape - groups, K, _ = self.weight.shape - - if cache is not None: - x = mx.concatenate([cache, x], axis=1) - else: - x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - - y = mx.conv_general(x, self.weight, groups=groups) - - if self.bias is not None: - y = y + self.bias - - return y, x[:, -K + 1 :, :] - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - self.use_bcdt_rms = args.use_bcdt_rms - if self.use_bcdt_rms: - self.mixer_norm = lambda x: mx.fast.rms_norm( - x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps - ) - - self.in_proj = nn.Linear( - self.hidden_size, self.intermediate_size * 2, bias=args.use_bias - ) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size - 1, - ) - - self.x_proj = nn.Linear( - self.intermediate_size, - self.time_step_rank + 2 * self.ssm_state_size, - bias=False, - ) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - A = mx.repeat( - mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]), - repeats=self.intermediate_size, - axis=0, - ) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear( - self.intermediate_size, self.hidden_size, 
bias=args.use_bias - ) - - def ssm_step(self, x, A, state=None): - D = self.D - deltaBC = self.x_proj(x) - delta, B, C = map( - self.mixer_norm if self.use_bcdt_rms else lambda x: x, - mx.split( - deltaBC, - [self.time_step_rank, self.time_step_rank + self.ssm_state_size], - axis=-1, - ), - ) - if self.use_bcdt_rms: - delta, B, C = map(self.mixer_norm, (delta, B, C)) - delta = nn.softplus(self.dt_proj(delta)) - new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) - if state is not None: - new_state += state * mx.exp(mx.expand_dims(delta, -1) * A) - y = (new_state @ mx.expand_dims(C, -1)).squeeze(2) - y = y + D * x - return y, new_state - - def _process_sequence(self, x, conv_cache, state_cache): - B, T, D = x.shape - xz = self.in_proj(x) - x, z = xz.split(indices_or_sections=2, axis=-1) - - conv_out, new_conv_cache = self.conv1d(x, conv_cache) - x = nn.silu(conv_out) - - A = -mx.exp(self.A_log) - - outputs = [] - current_state = state_cache - y = [] - for t in range(T): - y_t, current_state = self.ssm_step(x[:, t], A, current_state) - y.append(y_t) - y = mx.stack(y, axis=1) - z = self.out_proj(nn.silu(z) * y) - return z, (new_conv_cache, current_state) - - def __call__(self, x, cache): - if cache is None: - conv_cache, state_cache = None, None - else: - conv_cache, state_cache = cache[0], cache[1] - - output, (new_conv_cache, new_state_cache) = self._process_sequence( - x, conv_cache, state_cache - ) - - if isinstance(cache, MambaCache): - cache[0] = new_conv_cache - cache[1] = new_state_cache - - return output - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - return self.mixer(self.norm(x), cache) + x - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.shape[-1] != 1: - weights[k] = v.moveaxis(2, 1) - return weights - - def make_cache(self): - return [MambaCache() for _ in range(len(self.layers))] - - @property - def layers(self): - return self.backbone.layers diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py deleted file mode 100644 index 7140c577..00000000 --- a/llms/mlx_lm/models/minicpm.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright © 2023-2025 Apple Inc. 
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - dim_model_base: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int - scale_depth: float - scale_emb: float - rope_theta: float = 1000000.0 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[str, float]]] = None - tie_word_embeddings: bool = False - - -class MLP(nn.Module): - def __init__(self, args): - super().__init__() - self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) - self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) - self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) - - def __call__(self, x): - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.num_heads = n_heads = args.num_attention_heads - self.rope_theta = args.rope_theta - - self.head_dim = head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.num_key_value_heads = args.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - - self.rope = nn.RoPE( - dims=self.head_dim, - traditional=args.rope_traditional, - base=self.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ): - B, L, _ = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 - ) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - attn_output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.o_proj(attn_output) - - -class DecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.hidden_size = args.hidden_size - self.num_hidden_layers = args.num_hidden_layers - - self.self_attn = Attention(args) - self.mlp = MLP(args) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = 
nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - self.scale_depth = args.scale_depth - self.num_hidden_layers = args.num_hidden_layers - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) - return out - - -class MiniCPMModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - assert self.vocab_size > 0 - - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) * self.args.scale_emb - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = MiniCPMModel(args) - - if not self.args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - - if not self.args.tie_word_embeddings: - out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) - else: - out = out @ self.model.embed_tokens.weight.T - - return out - - def sanitize(self, weights): - if "lm_head.weight" not in weights: - weights["lm_head.weight"] = weights["model.embed_tokens.weight"] - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py deleted file mode 100644 index 0afd1235..00000000 --- a/llms/mlx_lm/models/mixtral.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -import math -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int = 32000 - hidden_size: int = 4096 - intermediate_size: int = 14336 - num_hidden_layers: int = 32 - num_attention_heads: int = 32 - num_experts_per_tok: int = 2 - num_key_value_heads: int = 8 - num_local_experts: int = 8 - rms_norm_eps: float = 1e-5 - rope_theta: float = 1e6 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class MixtralAttention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.num_heads = args.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = args.num_key_value_heads - self.rope_theta = args.rope_theta - - self.scale = self.head_dim**-0.5 - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.rope = nn.RoPE( - self.head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 - ) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MixtralSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_dim = args.hidden_size - self.ffn_dim = args.intermediate_size - self.num_experts = args.num_local_experts - self.num_experts_per_tok = args.num_experts_per_tok - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts) - - def __call__(self, x: mx.array) -> mx.array: - gates = self.gate(x) - - k = self.num_experts_per_tok - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) - scores = mx.take_along_axis(gates, inds, axis=-1) - scores = mx.softmax(scores, axis=-1, precise=True) - - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - - return y - - -class 
MixtralDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - - self.self_attn = MixtralAttention(args) - - self.block_sparse_moe = MixtralSparseMoeBlock(args) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.block_sparse_moe(self.post_attention_layernorm(h)) - out = h + r - return out - - -class MixtralModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = MixtralModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - return self.lm_head(out) - - def sanitize(self, weights): - if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: - return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop( - f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}" - ) - for e in range(self.args.num_local_experts) - ] - weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( - mx.stack(to_join) - ) - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py deleted file mode 100644 index eabfac8c..00000000 --- a/llms/mlx_lm/models/nemotron.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright © 2024 Apple Inc. 
- -from dataclasses import dataclass -from functools import partial -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - hidden_act: str - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - norm_eps: float - vocab_size: int - num_key_value_heads: int - head_dim: Optional[int] = None - max_position_embeddings: Optional[int] = None - attention_bias: bool = False - mlp_bias: bool = False - partial_rotary_factor: float = 0.5 - rope_theta: float = 10000.0 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = False - - def __post_init__(self): - if self.rope_scaling: - if not "factor" in self.rope_scaling: - raise ValueError(f"rope_scaling must contain 'factor'") - rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( - "rope_type" - ) - if rope_type is None: - raise ValueError( - f"rope_scaling must contain either 'type' or 'rope_type'" - ) - if rope_type not in ["linear"]: - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - -@partial(mx.compile, shapeless=True) -def relu_squared(x): - return nn.relu(x).square() - - -class NemotronLayerNorm1P(nn.LayerNorm): - def __call__(self, x): - weight = self.weight + 1 if "weight" in self else None - bias = self.bias if "bias" in self else None - return mx.fast.layer_norm(x, weight, bias, self.eps) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads - self.partial_rotary_factor = args.partial_rotary_factor - - self.scale = head_dim**-0.5 - if hasattr(args, "attention_bias"): - attention_bias = args.attention_bias - else: - attention_bias = False - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - - rope_scale = 1.0 - if args.rope_scaling and args.rope_scaling["type"] == "linear": - assert isinstance(args.rope_scaling["factor"], float) - rope_scale = 1 / args.rope_scaling["factor"] - self.rope = nn.RoPE( - int(self.partial_rotary_factor * self.head_dim), - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, _ = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, 
cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - hidden_dim = args.intermediate_size - mlp_bias = args.mlp_bias - - self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) - self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - - def __call__(self, x) -> mx.array: - return self.down_proj(relu_squared(self.up_proj(x))) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args) - self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps) - self.post_attention_layernorm = NemotronLayerNorm1P( - args.hidden_size, eps=args.norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class NemotronModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = NemotronModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py deleted file mode 100644 index 4273b0ec..00000000 --- a/llms/mlx_lm/models/olmo.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -import sys -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask - -try: - import hf_olmo -except ImportError: - print("To run olmo install ai2-olmo: pip install ai2-olmo") - sys.exit(1) - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - d_model: int - n_layers: int - mlp_hidden_size: int - n_heads: int - vocab_size: int - embedding_size: int - rope_theta: float = 10000 - rope_traditional: bool = False - mlp_ratio: int = 4 - weight_tying: bool = False - - def __post_init__(self): - self.mlp_hidden_size = ( - self.mlp_hidden_size - if self.mlp_hidden_size is not None - else self.mlp_ratio * self.d_model - ) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - dim = args.d_model - - self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False) - self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False) - - self.att_norm = nn.LayerNorm(dim, affine=False) - self.ff_norm = nn.LayerNorm(dim, affine=False) - - head_dim = dim // self.n_heads - self.scale = head_dim**-0.5 - - self.att_proj = nn.Linear(dim, 3 * dim, bias=False) - self.attn_out = nn.Linear(dim, dim, bias=False) - - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - self.args = args - - def attend( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.attn_out(output) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.attend(self.att_norm(x), mask, cache) - h = x + r - - x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1) - - out = h + self.ff_out(nn.silu(x2) * x1) - return out - - -class Transformer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_layers = args.n_layers - self.weight_tying = args.weight_tying - - self.wte = nn.Embedding(args.embedding_size, args.d_model) - self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)] - if not self.weight_tying: - self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False) - self.norm = nn.LayerNorm(args.d_model, affine=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.wte(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.blocks) - - for block, c in zip(self.blocks, cache): - h = block(h, mask, c) - - h = 
self.norm(h) - - if self.weight_tying: - return self.wte.as_linear(h) - - return self.ff_out(h) - - -class OlmoModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.transformer = Transformer(args) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - return self.transformer(inputs, mask, cache) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = OlmoModel(args) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - return self.model(inputs, mask, cache) - - @property - def layers(self): - return self.model.transformer.blocks diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py deleted file mode 100644 index 510ff882..00000000 --- a/llms/mlx_lm/models/olmo2.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .rope_utils import initialize_rope - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - head_dim: Optional[int] = None - max_position_embeddings: Optional[int] = None - num_key_value_heads: Optional[int] = None - attention_bias: bool = False - mlp_bias: bool = False - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = True - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads - - self.scale = head_dim**-0.5 - if hasattr(args, "attention_bias"): - attention_bias = args.attention_bias - else: - attention_bias = False - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - - self.rope = initialize_rope( - self.head_dim, - args.rope_theta, - args.rope_traditional, - args.rope_scaling, - args.max_position_embeddings, - ) - - self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) - self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - queries = self.q_norm(queries) - keys = self.k_norm(keys) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries =
self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - hidden_dim = args.intermediate_size - if hasattr(args, "mlp_bias"): - mlp_bias = args.mlp_bias - else: - mlp_bias = False - - self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) - self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.post_feedforward_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.post_attention_layernorm(self.self_attn(x, mask, cache)) - h = x + r - r = self.post_feedforward_layernorm(self.mlp(h)) - out = h + r - return out - - -class LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - mask=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = LlamaModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - mask=None, - ): - out = self.model(inputs, cache, mask) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - def sanitize(self, weights): - # Remove unused precomputed rotary freqs - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py deleted file mode 100644 index b9c0fc69..00000000 --- a/llms/mlx_lm/models/olmoe.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
- -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .rope_utils import initialize_rope -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_experts: int - num_experts_per_tok: int - norm_topk_prob: bool = False - head_dim: Optional[int] = None - max_position_embeddings: Optional[int] = None - num_key_value_heads: Optional[int] = None - attention_bias: bool = False - mlp_bias: bool = False - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = True - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads - - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) - - self.rope = initialize_rope( - self.head_dim, - args.rope_theta, - args.rope_traditional, - args.rope_scaling, - args.max_position_embeddings, - ) - - self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) - self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - queries = self.q_norm(queries) - keys = self.k_norm(keys) - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class OlmoeSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_experts = args.num_experts - self.top_k = args.num_experts_per_tok - self.norm_topk_prob = args.norm_topk_prob - - self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False) - self.switch_mlp = SwitchGLU( - args.hidden_size, - args.intermediate_size, - self.num_experts, - bias=args.mlp_bias, - ) - - def __call__(self, x: mx.array) -> mx.array: - B, L, D = x.shape - x_flat = x.reshape(-1, D) - router_logits = self.gate(x_flat) - routing_weights = mx.softmax(router_logits, axis=1, precise=True) - k = self.top_k 
- indices = mx.stop_gradient( - mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k] - ) - scores = mx.take_along_axis(routing_weights, indices, axis=-1) - if self.norm_topk_prob: - scores = scores / scores.sum(axis=-1, keepdims=True) - y = self.switch_mlp(x_flat, indices) - y = (y * scores[..., None]).sum(axis=-2) - return y.reshape(B, L, D) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.self_attn = Attention(args) - self.mlp = OlmoeSparseMoeBlock(args) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - x = x + self.self_attn(self.input_layernorm(x), mask, cache) - x = x + self.mlp(self.post_attention_layernorm(x)) - return x - - -class OlmoeModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - mask=None, - ): - h = self.embed_tokens(inputs) - if mask is None: - mask = create_attention_mask(h, cache) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = OlmoeModel(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - mask=None, - ): - out = self.model(inputs, cache, mask) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - def sanitize(self, weights): - if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: - return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py deleted file mode 100644 index 504fe95c..00000000 --- a/llms/mlx_lm/models/openelm.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
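# A minimal, runnable sketch of the top-k expert routing computed by
# OlmoeSparseMoeBlock above (softmax over router logits, argpartition to pick
# the top-k experts, gather of their scores, renormalization on the
# norm_topk_prob=True path). Sizes are toy values and the experts are plain
# Linear layers standing in for the fused SwitchGLU; an illustration only,
# not the shipped kernel.
import mlx.core as mx
import mlx.nn as nn

num_experts, top_k, dim = 8, 2, 16
gate = nn.Linear(dim, num_experts, bias=False)
experts = [nn.Linear(dim, dim) for _ in range(num_experts)]

x = mx.random.normal((4, dim))                        # 4 flattened tokens
weights = mx.softmax(gate(x), axis=-1)                # router probabilities
inds = mx.argpartition(-weights, kth=top_k - 1, axis=-1)[..., :top_k]
scores = mx.take_along_axis(weights, inds, axis=-1)
scores = scores / scores.sum(axis=-1, keepdims=True)  # renormalize the top-k

# Apply the selected experts per token and mix by the routing scores.
outs = []
for t in range(x.shape[0]):
    tok = mx.zeros((dim,))
    for j in range(top_k):
        e = inds[t, j].item()
        tok = tok + scores[t, j] * experts[e](x[t])
    outs.append(tok)
y = mx.stack(outs)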
- -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - head_dim: int - num_transformer_layers: int - model_dim: int - vocab_size: int - ffn_dim_divisor: int - num_query_heads: List - num_kv_heads: List - ffn_multipliers: List - ffn_with_glu: bool = True - normalize_qk_projections: bool = True - share_input_output_layers: bool = True - rms_norm_eps: float = 1e-6 - rope_freq_constant: float = 10000 - - -def make_divisible( - v: Union[float, int], - divisor: Optional[int] = 8, - min_value: Optional[Union[float, int]] = None, -) -> Union[float, int]: - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by the divisor - It can be seen at: - https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62 - Args: - v: input value - divisor: default to 8 - min_value: minimum divisor value - Returns: - new_v: new divisible value - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int): - super().__init__() - self.head_dim = head_dim = args.head_dim - self.layer_id = layer_id - self.model_dim = model_dim = args.model_dim - - self.n_heads = n_heads = args.num_query_heads[layer_id] - self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id] - self.scale = head_dim**-0.5 - - op_size = (n_heads + (n_kv_heads * 2)) * head_dim - self.qkv_proj = nn.Linear(model_dim, op_size, bias=False) - self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False) - - self.normalize_qk_projections = args.normalize_qk_projections - - if self.normalize_qk_projections: - self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) - self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) - - self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv = self.qkv_proj(x) - - qkv = qkv.reshape( - B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim - ).transpose(0, 2, 1, 3) - - queries, keys, values = mx.split( - qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1 - ) - - # Prepare the queries, keys and values for the attention computation - if self.normalize_qk_projections: - queries = self.q_norm(queries) - keys = self.k_norm(keys) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.out_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int): - super().__init__() - self.args = args - dim = args.model_dim - ffn_multiplier = args.ffn_multipliers[layer_id] - - intermediate_dim = int( - 
make_divisible( - ffn_multiplier * args.model_dim, - divisor=args.ffn_dim_divisor, - ) - ) - - self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False) - self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False) - - def __call__(self, x) -> mx.array: - x = self.proj_1(x) - gate, x = mx.split(x, 2, axis=-1) - return self.proj_2(nn.silu(gate) * x) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int): - super().__init__() - dim = args.model_dim - self.attn = Attention(args, layer_id=layer_id) - self.ffn = MLP(args, layer_id=layer_id) - self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps) - self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.attn(self.attn_norm(x), mask, cache) - h = x + r - r = self.ffn(self.ffn_norm(h)) - out = h + r - return out - - -class OpenELMModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_transformer_layers = args.num_transformer_layers - assert self.vocab_size > 0 - self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim) - self.layers = [ - TransformerBlock(args, layer_id=layer_id) - for layer_id in range(self.num_transformer_layers) - ] - self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.token_embeddings(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.transformer = OpenELMModel(args) - if not args.share_input_output_layers: - self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.transformer(inputs, mask, cache) - if self.args.share_input_output_layers: - out = self.transformer.token_embeddings.as_linear(out) - else: - out = self.lm_head(out) - - return out - - @property - def layers(self): - return self.transformer.layers diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py deleted file mode 100644 index e9724691..00000000 --- a/llms/mlx_lm/models/phi.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
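# Worked example for the make_divisible helper defined in openelm.py above:
# it rounds a channel count to the nearest multiple of `divisor` and bumps the
# result up one step if rounding would lose more than 10% of the input.
# Restated here so the numbers can be checked standalone; values are arbitrary.
def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

assert make_divisible(100, 8) == 104  # (100 + 4) // 8 * 8 = 104
assert make_divisible(23, 8) == 24    # (23 + 4) // 8 * 8 = 24
assert make_divisible(27, 8) == 32    # rounds to 24, but 24 < 0.9 * 27, so +8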
- -import math -from dataclasses import dataclass -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "phi" - max_position_embeddings: int = 2048 - vocab_size: int = 51200 - hidden_size: int = 2560 - num_attention_heads: int = 32 - num_hidden_layers: int = 32 - num_key_value_heads: int = 32 - partial_rotary_factor: float = 0.4 - intermediate_size: int = 10240 - layer_norm_eps: float = 1e-5 - rope_theta: float = 10000.0 - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class PhiAttention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.repeats = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.partial_rotary_factor = config.partial_rotary_factor - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=True - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True - ) - self.dense = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=True - ) - - self.rope = nn.RoPE( - int(self.partial_rotary_factor * self.head_dim), - traditional=False, - base=self.rope_theta, - ) - - def __call__(self, x, mask=None, cache=None): - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Extract some shapes - B, L, D = queries.shape - n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape( - B, - L, - n_heads, - -1, - ).moveaxis(1, 2) - keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2) - values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scale = math.sqrt(1 / queries.shape[-1]) - output = scaled_dot_product_attention( - queries.astype(mx.float32), - keys, - values, - cache=cache, - scale=scale, - mask=mask, - ).astype(values.dtype) - - output = output.moveaxis(2, 1).reshape(B, L, -1) - - return self.dense(output) - - -class PhiMLP(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - self.act = nn.GELU(approx="precise") - - def __call__(self, x) -> mx.array: - return self.fc2(self.act(self.fc1(x))) - - -class PhiDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.self_attn = PhiAttention(config=config) - self.input_layernorm = nn.LayerNorm( - 
config.hidden_size, eps=config.layer_norm_eps - ) - self.mlp = PhiMLP(config) - - def __call__(self, x, mask, cache): - h = self.input_layernorm(x) - attn_h = self.self_attn(h, mask, cache) - ff_h = self.mlp(h) - return attn_h + ff_h + x - - -class PhiModel(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)] - self.final_layernorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - - def __call__(self, x, mask, cache): - x = self.embed_tokens(x) - - if mask is None: - mask = create_attention_mask(x, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - x = layer(x, mask, c) - return self.final_layernorm(x) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.model = PhiModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) - self.args = config - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - y = self.model(x, mask, cache) - return self.lm_head(y) - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py deleted file mode 100644 index 63e985de..00000000 --- a/llms/mlx_lm/models/phi3.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .su_rope import SuScaledRotaryEmbedding - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: Optional[int] = None - rope_theta: float = 10000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None - partial_rotary_factor: float = 1.0 - max_position_embeddings: int = 131072 - original_max_position_embeddings: int = 4096 - tie_word_embeddings: bool = False - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"long_factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] not in ["longrope", "su", "linear"]: - print( - "[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false." 
- ) - self.rope_scaling = None - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.num_hidden_layers = args.num_hidden_layers - - self.head_dim = head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim) - self.qkv_proj = nn.Linear(dim, op_size, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - rope_dim = int(head_dim * args.partial_rotary_factor) - if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]: - self.rope = SuScaledRotaryEmbedding( - rope_dim, - base=args.rope_theta, - max_position_embeddings=args.max_position_embeddings, - original_max_position_embeddings=args.original_max_position_embeddings, - short_factor=args.rope_scaling["short_factor"], - long_factor=args.rope_scaling["long_factor"], - ) - else: - rope_scale = 1.0 - if args.rope_scaling and args.rope_scaling["type"] == "linear": - assert isinstance(args.rope_scaling["factor"], float) - rope_scale = 1 / args.rope_scaling["factor"] - self.rope = nn.RoPE( - rope_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv = self.qkv_proj(x) - query_pos = self.n_heads * self.head_dim - queries, keys, values = mx.split( - qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1 - ) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - - def __call__(self, x) -> mx.array: - x = self.gate_up_proj(x) - gate, x = mx.split(x, 2, axis=-1) - return self.down_proj(nn.silu(gate) * x) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - 
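# Small sketch of the fused QKV projection used by phi3's Attention above: one
# Linear produces [queries | keys | values] and mx.split cuts it at the head
# boundaries. Toy sizes chosen only to make the shapes easy to read.
import mlx.core as mx
import mlx.nn as nn

B, L, dim, n_heads, n_kv_heads, head_dim = 1, 4, 64, 8, 2, 8
qkv_proj = nn.Linear(dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=False)

x = mx.random.normal((B, L, dim))
qkv = qkv_proj(x)
query_pos = n_heads * head_dim
queries, keys, values = mx.split(
    qkv, [query_pos, query_pos + n_kv_heads * head_dim], axis=-1
)
print(queries.shape, keys.shape, values.shape)  # (1, 4, 64) (1, 4, 16) (1, 4, 16)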
-class Phi3Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = Phi3Model(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - self.args = args - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py deleted file mode 100644 index cd566eec..00000000 --- a/llms/mlx_lm/models/phi3small.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math -from dataclasses import dataclass -from functools import partial -from typing import Any, Optional - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - dense_attention_every_n_layers: int - ff_intermediate_size: int - gegelu_limit: float - num_hidden_layers: int - num_attention_heads: int - layer_norm_epsilon: float - vocab_size: int - num_key_value_heads: int - mup_attn_multiplier: float = 1.0 - mup_use_scaling: bool = True - mup_embedding_multiplier: float = 10.0 - mup_width_multiplier: float = 8.0 - rope_embedding_base: float = 1000000 - rope_position_scale: float = 1.0 - blocksparse_block_size: int = 64 - blocksparse_num_local_blocks: int = 16 - blocksparse_vert_stride: int = 8 - - -@partial(mx.compile, shapeless=True) -def gegelu_impl(a_gelu, a_linear, limit): - a_gelu = mx.where( - mx.isinf(a_gelu), - a_gelu, - mx.clip(a_gelu, a_min=None, a_max=limit), - ) - a_linear = mx.where( - mx.isinf(a_linear), - a_linear, - mx.clip(a_linear, a_min=-limit, a_max=limit), - ) - out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu) - return out_gelu * (a_linear + 1.0) - - -def gegelu(x, limit): - a_gelu, a_linear = x[..., ::2], x[..., 1::2] - return gegelu_impl(a_gelu, a_linear, limit) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_idx): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - self.n_q_per_kv = n_heads // n_kv_heads - - self.head_dim = head_dim = args.hidden_size // n_heads - - self.query_key_value = nn.Linear( - dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim - ) - self.dense = nn.Linear(dim, dim) - - if args.mup_use_scaling: - norm_factor = head_dim / args.mup_attn_multiplier 
- else: - norm_factor = math.sqrt(head_dim) - self.scale = 1.0 / norm_factor - - self.rope = nn.RoPE( - head_dim, - traditional=False, - base=args.rope_embedding_base, - scale=args.rope_position_scale, - ) - - if layer_idx % args.dense_attention_every_n_layers == 0: - self.block_sparse = True - self.blocksparse_block_size = args.blocksparse_block_size - if self.blocksparse_block_size not in (32, 64): - raise ValueError( - f"Unsupported block size {self.blocksparse_block_size}" - ) - self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks - self.blocksparse_vert_stride = args.blocksparse_vert_stride - else: - self.block_sparse = False - - def _block_sparse_mask(self, q_len, kv_len): - vert_stride = self.blocksparse_vert_stride - local_blocks = self.blocksparse_num_local_blocks - block_size = self.blocksparse_block_size - n_heads = self.n_heads - - kv_blocks = (kv_len + block_size - 1) // block_size - q_blocks = (q_len + block_size - 1) // block_size - q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None] - k_pos = mx.arange(kv_blocks)[None, None] - - mask_vert_strided = ( - mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None] - ) % vert_stride - mask_vert_strided = (mask_vert_strided == 0)[:, None, :] - - block_mask = (q_pos >= k_pos) & ( - (q_pos - k_pos < local_blocks) | mask_vert_strided - ) - block_mask = block_mask.reshape( - self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:] - ) - dense_mask = mx.repeat( - mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2 - ) - return block_mask, dense_mask[..., -q_len:, :kv_len] - - def _block_sparse_attention(self, queries, keys, values, scale, mask): - queries = scale * queries - B = queries.shape[0] - L = queries.shape[2] - queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1)) - keys = mx.expand_dims(keys, 2) - values = mx.expand_dims(values, 2) - - # TODO get rid of dense mask if we have a fill value - block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2]) - scores = queries @ mx.swapaxes(keys, -1, -2) - # TODO, uncomment when faster - # scores = mx.block_masked_mm( - # queries, - # mx.swapaxes(keys, -1, -2), - # mask_out=block_mask, - # block_size=self.blocksparse_block_size, - # ) - - if mask is not None: - scores = scores + mask - scores = scores + mx.where( - dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype) - ) - scores = mx.softmax(scores, axis=-1, precise=True) - - output = scores @ values - # TODO, uncomment when faster - # output = mx.block_masked_mm( - # scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size - # ) - return mx.reshape(output, (B, self.n_heads, L, -1)) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - qkv = self.query_key_value(x) - qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim) - queries = qkv[..., :-2, :].flatten(-3, -2) - keys = qkv[..., -2, :] - values = qkv[..., -1, :] - - # Prepare the queries, keys and values for the attention computation - queries = queries.transpose(0, 2, 1, 3) - keys = keys.transpose(0, 2, 1, 3) - values = values.transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - if self.block_sparse: - output = self._block_sparse_attention( - 
queries, keys, values, scale=self.scale, mask=mask - ) - else: - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.dense(output) - - -class MLP(nn.Module): - def __init__(self, args): - super().__init__() - dim = args.hidden_size - hidden_dim = args.ff_intermediate_size - self.gegelu_limit = args.gegelu_limit - self.up_proj = nn.Linear(dim, 2 * hidden_dim) - self.down_proj = nn.Linear(hidden_dim, dim) - - def __call__(self, x) -> mx.array: - x = self.up_proj(x) - return self.down_proj(gegelu(x, self.gegelu_limit)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs, layer_idx): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args, layer_idx) - self.mlp = MLP(args) - self.input_layernorm = nn.LayerNorm( - args.hidden_size, eps=args.layer_norm_epsilon - ) - self.post_attention_layernorm = nn.LayerNorm( - args.hidden_size, - eps=args.layer_norm_epsilon, - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class Phi3Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.mup_embedding_multiplier = args.mup_embedding_multiplier - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args, layer_idx=l) - for l in range(args.num_hidden_layers) - ] - self.final_layernorm = nn.LayerNorm( - args.hidden_size, eps=args.layer_norm_epsilon - ) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - if self.mup_embedding_multiplier: - h = self.mup_embedding_multiplier * h - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.final_layernorm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = Phi3Model(args) - self.args = args - self.mup_width_multiplier = args.mup_width_multiplier - self._dummy_tokenizer_ids = mx.array( - [100256, 100258, 100259, 100260, 100264, 100265] - + list(range(100267, 100352)) - ) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - out = self.model.embed_tokens.as_linear(out) - if self.mup_width_multiplier: - out = out / self.mup_width_multiplier - out[self._dummy_tokenizer_ids] = -float("inf") - return out - - @property - def layers(self): - return self.model.layers - - def sanitize(self, weights): - # Remove unused precomputed rotary freqs - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py deleted file mode 100644 index bddcb128..00000000 --- a/llms/mlx_lm/models/phimoe.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright © 2024 Apple Inc. 
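# Simplified numeric check of the gegelu activation defined in phi3small.py
# above: the up-projection interleaves "gelu" and "linear" channels, both are
# clipped against `limit`, and the output is gelu(a) * (b + 1). This sketch
# drops the isinf passthrough of the original and uses made-up values.
import mlx.core as mx

def gegelu_sketch(x, limit):
    a, b = x[..., ::2], x[..., 1::2]
    a = mx.clip(a, a_min=None, a_max=limit)
    b = mx.clip(b, a_min=-limit, a_max=limit)
    return (a * mx.sigmoid(1.702 * a)) * (b + 1.0)

x = mx.arange(8, dtype=mx.float32).reshape(1, 8) - 4.0
print(gegelu_sketch(x, limit=2.0))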
-import math -from dataclasses import dataclass -from typing import Dict, List, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .su_rope import SuScaledRotaryEmbedding -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "phimoe" - vocab_size: int = 32064 - hidden_size: int = 4096 - intermediate_size: int = 6400 - num_hidden_layers: int = 32 - num_attention_heads: int = 32 - num_key_value_heads: int = 8 - max_position_embeddings: int = 131072 - original_max_position_embeddings: int = 4096 - rms_norm_eps: float = 1e-6 - rope_scaling: Dict[str, Union[float, List[float]]] = None - num_local_experts: int = 16 - num_experts_per_tok: int = 2 - rope_theta: float = 10000.0 - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) - - self.rope = SuScaledRotaryEmbedding( - head_dim, - base=args.rope_theta, - max_position_embeddings=args.max_position_embeddings, - original_max_position_embeddings=args.original_max_position_embeddings, - short_factor=args.rope_scaling["short_factor"], - long_factor=args.rope_scaling["long_factor"], - short_mscale=args.rope_scaling["short_mscale"], - long_mscale=args.rope_scaling["long_mscale"], - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache=None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class PhiMoESparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_dim = args.hidden_size - self.ffn_dim = args.intermediate_size - self.num_experts = args.num_local_experts - self.top_k = args.num_experts_per_tok - - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts) - - def __call__(self, x: mx.array) -> mx.array: - gates = self.gate(x) - - k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) - scores = mx.take_along_axis(gates, inds, axis=-1) - scores = mx.softmax(scores, axis=-1, precise=True) - - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - - return y - - -class 
PhiMoEDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - - self.self_attn = Attention(args) - self.block_sparse_moe = PhiMoESparseMoeBlock(args) - self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.LayerNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache=None, - ) -> mx.array: - residual = x - hidden_states = self.input_layernorm(x) - hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.block_sparse_moe(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - - -class PhiMoEModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)] - self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.args = args - self.model = PhiMoEModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - return self.lm_head(out) - - def sanitize(self, weights): - if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: - return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop( - f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}" - ) - for e in range(self.args.num_local_experts) - ] - weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( - mx.stack(to_join) - ) - - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py deleted file mode 100644 index 5477c2c0..00000000 --- a/llms/mlx_lm/models/phixtral.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
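# Sketch of the expert-stacking pattern used by sanitize() above: checkpoints
# store one weight matrix per expert under separate keys, which are popped and
# mx.stack-ed into a single tensor per projection so the fused SwitchGLU
# module can index experts directly. Key names and shapes here are toy
# stand-ins, not the real checkpoint layout.
import mlx.core as mx

num_experts, d_in, d_out = 4, 8, 16
weights = {
    f"mlp.experts.{e}.up_proj.weight": mx.random.normal((d_out, d_in))
    for e in range(num_experts)
}

stacked = mx.stack(
    [weights.pop(f"mlp.experts.{e}.up_proj.weight") for e in range(num_experts)]
)
weights["mlp.switch_mlp.up_proj.weight"] = stacked
print(stacked.shape)  # (4, 16, 8)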
- -import inspect -import math -from dataclasses import dataclass -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchMLP - - -@dataclass -class ModelArgs: - model_type: str - num_vocab: int = 51200 - model_dim: int = 2560 - num_heads: int = 32 - num_layers: int = 32 - rotary_dim: int = 32 - num_experts_per_tok: int = 2 - num_local_experts: int = 4 - - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) - - -class RoPEAttention(nn.Module): - def __init__(self, dims: int, num_heads: int, rotary_dim: int): - super().__init__() - - self.num_heads = num_heads - - self.rope = nn.RoPE(rotary_dim, traditional=False) - self.Wqkv = nn.Linear(dims, 3 * dims) - self.out_proj = nn.Linear(dims, dims) - - def __call__(self, x, mask=None, cache=None): - qkv = self.Wqkv(x) - queries, keys, values = mx.split(qkv, 3, axis=-1) - - # Extract some shapes - num_heads = self.num_heads - B, L, D = queries.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries.astype(mx.float32) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - - output = scaled_dot_product_attention( - queries.astype(mx.float32), - keys, - values, - cache=cache, - scale=scale, - mask=mask, - ).astype(values.dtype) - output = output.moveaxis(2, 1).reshape(B, L, -1) - - return self.out_proj(output) - - -class MOE(nn.Module): - def __init__(self, args: ModelArgs, dim: int, hidden_dim: int): - super().__init__() - self.dim = dim - self.hidden_dim = hidden_dim - self.num_experts = args.num_local_experts - self.num_experts_per_tok = args.num_experts_per_tok - self.switch_mlp = SwitchMLP( - self.dim, self.hidden_dim, self.num_experts, bias=True - ) - self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False) - - def __call__(self, x: mx.array) -> mx.array: - gates = self.gate(x) - - k = self.num_experts_per_tok - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k] - scores = mx.take_along_axis(gates, inds, axis=-1) - scores = mx.softmax(scores, axis=-1, precise=True) - - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - - return y - - -class ParallelBlock(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - dims = config.model_dim - mlp_dims = dims * 4 - self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) - self.ln = nn.LayerNorm(dims) - self.moe = MOE(config, dims, mlp_dims) - - def __call__(self, x, mask, cache): - h = self.ln(x) - attn_h = self.mixer(h, mask, cache) - ff_h = self.moe(h) - return attn_h + ff_h + x - - -class TransformerDecoder(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embd = Embd(config) - self.h = [ParallelBlock(config) for i in range(config.num_layers)] - - def __call__(self, x, 
mask, cache): - x = self.embd(x) - if cache is None: - cache = [None] * len(self.h) - - for layer, c in zip(self.h, cache): - x = layer(x, mask, c) - return x - - -class Embd(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.wte = nn.Embedding(config.num_vocab, config.model_dim) - - def __call__(self, x): - return self.wte(x) - - -class OutputHead(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.ln = nn.LayerNorm(config.model_dim) - self.linear = nn.Linear(config.model_dim, config.num_vocab) - - def __call__(self, inputs): - return self.linear(self.ln(inputs)) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.transformer = TransformerDecoder(config) - self.lm_head = OutputHead(config) - self.args = config - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - - if mask is None: - mask = create_attention_mask(x, cache) - - y = self.transformer(x, mask, cache) - return self.lm_head(y) - - def sanitize(self, weights): - if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights: - return weights - for l in range(self.args.num_layers): - prefix = f"transformer.h.{l}" - for n in ["fc1", "fc2"]: - for k in ["weight", "scales", "biases", "bias"]: - if f"{prefix}.moe.mlp.0.{n}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}") - for e in range(self.args.num_local_experts) - ] - weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join) - return weights - - @property - def layers(self): - return self.transformer.h diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py deleted file mode 100644 index 9107daad..00000000 --- a/llms/mlx_lm/models/plamo.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
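# Rough illustration of the "parallel" residual layout used by phi.py and
# phixtral.py above: attention and the MLP/MoE both read the same normalized
# input and their outputs are added back to x in one step, unlike the
# sequential pre-norm blocks used by most other models in this directory.
# The Linear layers below are stand-ins for the real attention and MoE blocks.
import mlx.core as mx
import mlx.nn as nn

dim = 16
ln = nn.LayerNorm(dim)
attn_stub = nn.Linear(dim, dim)
mlp_stub = nn.Linear(dim, dim)

x = mx.random.normal((1, 4, dim))
h = ln(x)
out = attn_stub(h) + mlp_stub(h) + x  # attn_h + ff_h + x, as in ParallelBlock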
- -from dataclasses import dataclass -from typing import Any, Optional - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - n_shared_head: int = 8 - rope_theta: float = 10000 - rope_traditional: bool = False - - -class Attention(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - head_dim = self.hidden_size // config.num_attention_heads - - self.q_num_heads = config.num_attention_heads - self.qk_dim = self.v_dim = head_dim - self.k_num_heads = self.v_num_heads = int( - np.ceil(self.q_num_heads / config.n_shared_head) - ) - - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear( - self.hidden_size, self.q_num_heads * self.qk_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.k_num_heads * self.qk_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.v_num_heads * self.v_dim, bias=False - ) - self.o_proj = nn.Linear( - self.q_num_heads * self.v_dim, self.hidden_size, bias=False - ) - self.rotary_emb = nn.RoPE( - head_dim, - traditional=config.rope_traditional, - base=config.rope_theta, - scale=1.0, - ) - - def __call__( - self, - hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - bsz, q_len, _ = hidden_states.shape - - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose( - 0, 2, 1, 3 - ) - keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose( - 0, 2, 1, 3 - ) - values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose( - 0, 2, 1, 3 - ) - - if cache is not None: - queries = self.rotary_emb(queries, offset=cache.offset) - keys = self.rotary_emb(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rotary_emb(queries) - keys = self.rotary_emb(keys) - - keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) - values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) - - output = scaled_dot_product_attention( - queries, - keys, - values, - cache=cache, - scale=self.scale, - mask=attention_mask, - ) - output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def __call__(self, x: mx.array) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore - - -class PlamoDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.self_attn = Attention(config) - self.mlp 
= MLP(config) - self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ): - # from LlamaDecoder - residual = hidden_states - - hidden_states = self.norm(hidden_states) - - # Self Attention - hidden_states_sa = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - cache=cache, - ) - - # Fully Connected - hidden_states_mlp = self.mlp(hidden_states) - - hidden_states = residual + hidden_states_sa + hidden_states_mlp - return hidden_states - - -class PlamoDecoder(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.layers = [ - PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers) - ] - - -class PlamoModel(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = PlamoDecoder(config) # type: ignore - self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ) -> mx.array: - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None for _ in range(len(self.layers.layers))] - - for layer, c in zip(self.layers.layers, cache): - h = layer(h, mask, cache=c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs) -> None: - super().__init__() - self.model_type = args.model_type - self.model = PlamoModel(args) - self.lm_head: nn.Module = nn.Linear( - args.hidden_size, args.vocab_size, bias=False - ) - self.args = args - - def __call__( - self, - inputs: mx.array, - cache: Optional[Any] = None, - mask: Optional[mx.array] = None, - ) -> mx.array: - out = self.model(inputs, cache, mask) - return self.lm_head(out) - - @property - def layers(self): - return self.model.layers.layers diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py deleted file mode 100644 index 657fa02e..00000000 --- a/llms/mlx_lm/models/plamo2.py +++ /dev/null @@ -1,608 +0,0 @@ -# Copyright © 2025 Apple Inc. 
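# Shape sketch of the key/value sharing in plamo.py's Attention above: the
# model has fewer KV heads than query heads, so keys and values are tiled
# along the head axis with mx.tile until they match the query head count.
# Head counts and dims below are arbitrary toy values.
import mlx.core as mx

B, L, head_dim = 1, 5, 4
q_num_heads, k_num_heads = 8, 2
n_shared_head = q_num_heads // k_num_heads  # each KV head serves 4 query heads

keys = mx.random.normal((B, k_num_heads, L, head_dim))
keys = mx.tile(keys, [1, n_shared_head, 1, 1])
print(keys.shape)  # (1, 8, 5, 4)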
- -import math -from dataclasses import dataclass -from typing import Any, Optional - -import mlx.core as mx -import mlx.nn as nn -from mlx_lm.models.base import BaseModelArgs, create_attention_mask - -from .cache import KVCache, MambaCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "plamo2" - hidden_size: int = 4096 - num_hidden_layers: int = 32 - rms_norm_eps: float = 1e-6 - tie_word_embeddings: bool = True - num_attention_heads: int = 32 - num_key_value_heads: int = 4 - hidden_size_per_head: int = 128 - max_position_embeddings: int = 2048 - attention_window_size: int = 2048 - full_attention_idx: Optional[list[int]] = None - mamba_d_state: int = 64 - mamba_d_conv: int = 4 - mamba_num_heads: int = 64 - mamba_step: int = 2 - mamba_chunk_size: int = 256 - mamba_enabled: bool = True - intermediate_size: int = 13312 - vocab_size: int = 32000 - - -class RMSNorm(nn.Module): - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - offset: float = 1.0, - ) -> None: - super().__init__() - self.weight = mx.zeros(hidden_size) - self.variance_epsilon = eps - self.offset = offset - - def __call__(self, hidden_states: mx.array) -> mx.array: - return mx.fast.rms_norm( - hidden_states, self.weight + self.offset, self.variance_epsilon - ) - - -def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.astype(mx.float32) - variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + eps) - hidden_states = hidden_states.astype(input_dtype) - - return hidden_states - - -def get_initial_dt_bias(num_heads: int) -> mx.array: - dt_min = 0.001 - dt_max = 0.1 - dt = mx.exp( - mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ) - dt = mx.clip(dt, a_min=1e-4, a_max=None) - inv_dt = dt + mx.log(-mx.expm1(-dt)) - return inv_dt - - -def get_initial_A(num_heads: int) -> mx.array: - A = mx.arange(1, num_heads + 1, dtype=mx.float32) - return mx.log(A) - - -# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219 -def selective_state_update_ref( - state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False -) -> tuple[mx.array, mx.array]: - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) - A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) - dt_bias: (dim,) or (nheads, dim) - Return: - out: (batch, dim) or (batch, nheads, dim) - """ - has_heads = state.ndim > 3 - if state.ndim == 3: - state = mx.expand_dims(state, 1) - if x.ndim == 2: - x = mx.expand_dims(x, 1) - if dt.ndim == 2: - dt = mx.expand_dims(dt, 1) - if A.ndim == 2: - A = mx.expand_dims(A, 0) - if B.ndim == 2: - B = mx.expand_dims(B, 1) - if C.ndim == 2: - C = mx.expand_dims(C, 1) - if D is not None and D.ndim == 1: - D = mx.expand_dims(D, 0) - if z is not None and z.ndim == 2: - z = mx.expand_dims(z, 1) - if dt_bias is not None and dt_bias.ndim == 1: - dt_bias = mx.expand_dims(dt_bias, 0) - batch, nheads, dim, dstate = state.shape - assert x.shape == (batch, nheads, dim) - assert dt.shape == x.shape - assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[1] - assert nheads % 
ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (nheads, dim) - dt = dt + dt_bias - dt = nn.softplus(dt) if dt_softplus else dt - dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate) - B = mx.reshape( - mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2), - (batch, nheads, dstate), - ) # (batch, nheads, dstate) - C = mx.reshape( - mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2), - (batch, nheads, dstate), - ) # (batch, nheads, dstate) - dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims( - B, axis=-2 - ) # (batch, nheads, dim, dstate) - state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate) - out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C) - if D is not None: - out += (x * D).astype(out.dtype) - out = (out if z is None else out * nn.silu(z)).astype(x.dtype) - if not has_heads: - out = out.squeeze(1) - return out, state - - -def ssd_update_state( - ssm_state: mx.array, - x: mx.array, - dt: mx.array, - A: mx.array, - B: mx.array, - C: mx.array, - D: mx.array, - z: mx.array, - dt_bias: mx.array, - dt_softplus: bool, -) -> tuple[mx.array, mx.array]: - assert ssm_state.dtype == mx.float32 - dtype = x.dtype - - hidden_size_per_head = x.shape[-1] - d_state = B.shape[-1] - A = mx.broadcast_to( - A[:, None, None], (A.shape[0], hidden_size_per_head, d_state) - ).astype(mx.float32) - dt = mx.broadcast_to( - dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head) - ) - dt_bias = mx.broadcast_to( - dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head) - ) - D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head)) - out, ssm_state = selective_state_update_ref( - ssm_state, - x.astype(dtype), - dt.astype(dtype), - A.astype(mx.float32), - B.astype(dtype), - C.astype(dtype), - D.astype(mx.float32), - z.astype(dtype), - dt_bias.astype(mx.float32), - dt_softplus=dt_softplus, - ) - return out[:, None], ssm_state - - -def ssd_chunk_scan_combined( - x: mx.array, - dt: mx.array, - A: mx.array, - B: mx.array, - C: mx.array, - D: mx.array, - z: mx.array, - dt_bias: mx.array, - dt_softplus: bool, - ssm_state: mx.array, -) -> tuple[mx.array, mx.array]: - assert ssm_state.dtype == mx.float32 - length = x.shape[1] - ys = [] - for i in range(length): - y, ssm_state = ssd_update_state( - ssm_state, - x[:, i], - dt[:, i], - A, - B[:, i], - C[:, i], - D if D.ndim == 1 else D[:, i], - z=z[:, i], - dt_bias=dt_bias, - dt_softplus=dt_softplus, - ) - ys.append(y) - return mx.concatenate(ys, axis=1), ssm_state - - -def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]: - _, seqlen, dim = x.shape - state_len = conv_state.shape[-2] - x = mx.concatenate([conv_state, x], axis=-2) - conv_state = x[:, -state_len:] - out = mx.conv1d( - x, - weight, - padding=0, - groups=dim, - )[:, -seqlen:] - return nn.silu(out), conv_state - - -class Mamba(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.d_state = config.mamba_d_state - self.d_conv = config.mamba_d_conv - self.chunk_size = config.mamba_chunk_size - self.num_heads = config.mamba_num_heads - self.hidden_size_per_head = config.hidden_size_per_head - - self.intermediate_size = self.num_heads * self.hidden_size_per_head - - self.in_proj = 
nn.Linear( - self.hidden_size, 2 * self.intermediate_size, bias=False - ) - self.conv1d = nn.Conv1d( - in_channels=self.intermediate_size, - out_channels=self.intermediate_size, - bias=False, - kernel_size=self.d_conv, - groups=self.intermediate_size, - padding=0, - ) - self.dt_dim = max(64, self.hidden_size // 16) - self.bcdt_proj = nn.Linear( - self.intermediate_size, - self.dt_dim + 2 * self.d_state, - bias=False, - ) - self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False) - - self.dt_bias = get_initial_dt_bias(self.num_heads) - self.A_log = get_initial_A(self.num_heads) - self.D = mx.ones(self.num_heads, dtype=mx.float32) - - self.dt_norm_weight = mx.ones(self.dt_dim) - self.B_norm_weight = mx.ones(self.d_state) - self.C_norm_weight = mx.ones(self.d_state) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def __call__( - self, - hidden_states: mx.array, - mask: Optional[mx.array] = None, - cache=None, - ): - bsize, length, _ = hidden_states.shape - - if cache is not None and cache[0] is not None: - conv_state = cache[0] - ssm_state = cache[1] - else: - conv_state = mx.zeros( - (bsize, self.d_conv - 1, self.intermediate_size), - dtype=hidden_states.dtype, - ) - ssm_state = mx.zeros( - (bsize, self.num_heads, self.hidden_size_per_head, self.d_state), - dtype=mx.float32, - ) - - zx = self.in_proj(hidden_states) - zx = zx.reshape(bsize, length, self.num_heads, -1) - # z: (bsize, length, num_heads, hidden_size_per_head) - # x: (bsize, length, num_heads, hidden_size_per_head) - z, x = mx.split( - zx, - [ - self.hidden_size_per_head, - ], - axis=-1, - ) - - x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head) - x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight) - BCdt = self.bcdt_proj(x) - x = x.reshape(bsize, length, self.num_heads, -1) - B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) - - A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) - dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps) - B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps) - C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps) - - # (bsize, length, num_heads, 1) - dt = self.dt_proj(dt)[..., None] - - out, ssm_state = ssd_chunk_scan_combined( - x, - dt.reshape(bsize, length, -1), - A, - B, - C, - D=self.D, - z=z, - dt_bias=self.dt_bias, - dt_softplus=True, - ssm_state=ssm_state, - ) - - if cache is not None: - cache[0] = conv_state - cache[1] = ssm_state - y = self.out_proj(out.reshape(bsize, length, -1)) - - return y - - -class Attention(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - head_dim = config.hidden_size_per_head - self.max_position_embeddings = config.max_position_embeddings - self.scale = head_dim**-0.5 - - self.q_num_heads = config.num_attention_heads - self.qk_dim = self.v_dim = head_dim - self.k_num_heads = self.v_num_heads = config.num_key_value_heads - assert self.q_num_heads % self.k_num_heads == 0 - self.n_group = self.q_num_heads // self.k_num_heads - - self.q_proj_dim = self.q_num_heads * self.qk_dim - self.k_proj_dim = self.k_num_heads * self.qk_dim - self.v_proj_dim = self.k_num_heads * self.v_dim - self.qkv_proj = nn.Linear( - self.hidden_size, - self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, - bias=False, - ) - self.o_proj = nn.Linear( - self.q_num_heads * self.v_dim, self.hidden_size, bias=False - ) - - self.q_weight = 
mx.ones((self.q_num_heads, self.qk_dim)) - self.k_weight = mx.ones((self.k_num_heads, self.qk_dim)) - - self.rope = nn.RoPE(self.qk_dim) - - def __call__( - self, - hidden_states: mx.array, - mask: Optional[mx.array] = None, - cache=None, - ): - B, T, _ = hidden_states.shape - - qkv = self.qkv_proj(hidden_states) - q, k, v = mx.split( - qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1 - ) - q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3) - k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) - v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) - - q = _rms_norm(q, 1e-6) * self.q_weight[:, None] - k = _rms_norm(k, 1e-6) * self.k_weight[:, None] - - if cache is not None: - q = self.rope(q, offset=cache.offset) - k = self.rope(k, offset=cache.offset) - k, v = cache.update_and_fetch(k, v) - else: - q = self.rope(q) - k = self.rope(k) - - output = mx.fast.scaled_dot_product_attention( - q, - k, - v, - scale=self.scale, - mask=mask, - ) - output = output.transpose(0, 2, 1, 3).reshape( - B, T, self.q_num_heads * self.v_dim - ) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_up_proj = nn.Linear( - self.hidden_size, self.intermediate_size * 2, bias=False - ) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def __call__(self, x: mx.array) -> mx.array: - h = self.gate_up_proj(x) - hs = mx.split(h, 2, axis=-1) - return self.down_proj(nn.silu(hs[0]) * hs[1]) - - -class PlamoDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, is_mamba: bool) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.is_mamba = is_mamba - self.mixer: nn.Module - if is_mamba: - self.mixer = Mamba(config) - else: - self.mixer = Attention(config) - self.mlp = MLP(config) - self.pre_mixer_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 - ) - self.post_mixer_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5 - ) - self.pre_mlp_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 - ) - self.post_mlp_norm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5) - ) - - def __call__( - self, - hidden_states: mx.array, - mask: Optional[mx.array] = None, - cache=None, - ): - residual = hidden_states - hidden_states = self.pre_mixer_norm(hidden_states) - - hidden_states_sa = self.mixer( - hidden_states=hidden_states, - mask=mask, - cache=cache, - ) - - hidden_states_sa = self.post_mixer_norm(hidden_states_sa) - hidden_states = residual + hidden_states_sa - - residual = hidden_states - hidden_states = self.pre_mlp_norm(hidden_states) - - # Fully Connected - hidden_states_mlp = self.mlp(hidden_states) - - # Residual - hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp) - return residual + hidden_states_mlp - - -def is_mamba(config: ModelArgs, i: int) -> bool: - if not config.mamba_enabled: - return False - assert config.mamba_step > 1 - assert i < config.num_hidden_layers - - if config.num_hidden_layers <= (config.mamba_step // 2): - # use attention in last layer - return i != config.num_hidden_layers - 1 - return (i % config.mamba_step) != (config.mamba_step // 2) - - -class PlamoDecoder(nn.Module): - def __init__(self, config: ModelArgs) -> None: - 
super().__init__() - - self.layers = [ - PlamoDecoderLayer(config, is_mamba=is_mamba(config, i)) - for i in range(config.num_hidden_layers) - ] - - def __call__(self, x: mx.array, mask: mx.array, cache): - for i, decoder_layer in enumerate(self.layers): - x = decoder_layer( - x, - mask=mask, - cache=cache[i], - ) - return x - - -class PlamoModel(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - - self.config = config - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = PlamoDecoder(config) # type: ignore - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: Optional[mx.array] = None, - cache=None, - ): - batch_size, seq_length = inputs.shape - - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, [cache[1]] if cache is not None else None) - - if cache is None: - cache = [None] * len(self.layers.layers) - - # decoder layers - out = self.layers( - h, - mask, - cache, - ) - - return self.norm(out) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.model_type = config.model_type - self.model = PlamoModel(config) - - self.vocab_size = config.vocab_size - - if not config.tie_word_embeddings: - self.lm_head: nn.Module = nn.Linear( - config.hidden_size, self.vocab_size, bias=False - ) - - def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]: - for k, v in weights.items(): - if "conv1d.weight" in k and v.shape[-1] != 1: - weights[k] = v.moveaxis(2, 1) - return weights - - def make_cache(self): - # TODO use RotatingKVCache is not full_attn - # full_attn = self.layer_idx in self.config.full_attention_idx - return [MambaCache() if l.is_mamba else KVCache() for l in self.layers] - - def __call__( - self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None - ) -> mx.array: - outputs = self.model( - inputs=inputs, - mask=None, - cache=cache, - ) - if self.config.tie_word_embeddings: - logits = self.model.embed_tokens.as_linear(outputs) - else: - logits = self.lm_head(outputs) - - return logits - - @property - def layers(self): - return self.model.layers.layers diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py deleted file mode 100644 index ec8a0199..00000000 --- a/llms/mlx_lm/models/qwen.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
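# Illustrative sketch, not part of any file in this diff: how the plamo2
# `is_mamba` schedule above interleaves Mamba and attention mixers.
# The config values used here are hypothetical.
def layer_schedule(mamba_enabled: bool, mamba_step: int, num_hidden_layers: int):
    def is_mamba(i: int) -> bool:
        if not mamba_enabled:
            return False
        if num_hidden_layers <= (mamba_step // 2):
            # Use attention only in the last layer.
            return i != num_hidden_layers - 1
        return (i % mamba_step) != (mamba_step // 2)

    return ["mamba" if is_mamba(i) else "attention" for i in range(num_hidden_layers)]

# layer_schedule(True, 2, 8)
# -> ['mamba', 'attention', 'mamba', 'attention', 'mamba', 'attention', 'mamba', 'attention']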
- -from dataclasses import dataclass - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int = 2048 - num_attention_heads: int = 16 - num_hidden_layers: int = 24 - kv_channels: int = 128 - max_position_embeddings: int = 8192 - layer_norm_epsilon: float = 1e-6 - intermediate_size: int = 11008 - no_bias: bool = True - vocab_size: int = 151936 - num_key_value_heads = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - hidden_size = args.hidden_size - self.num_attention_heads = args.num_attention_heads - - hidden_size_per_attention_head = hidden_size // self.num_attention_heads - - self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False) - - proj_size = args.kv_channels * self.num_attention_heads - - self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True) - self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias) - - self.scale = hidden_size_per_attention_head**-0.5 - - def __call__(self, x, mask=None, cache=None): - qkv = self.c_attn(x) - - q, k, v = mx.split(qkv, 3, axis=-1) - - B, L, _ = q.shape - - queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rotary_emb(queries, offset=cache.offset) - keys = self.rotary_emb(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rotary_emb(queries) - keys = self.rotary_emb(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.c_proj(output) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.w1 = nn.Linear( - args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias - ) - self.w2 = nn.Linear( - args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias - ) - self.c_proj = nn.Linear( - args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias - ) - - def __call__(self, x): - a1 = self.w1(x) - a2 = self.w2(x) - return self.c_proj(a1 * nn.silu(a2)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.attn = Attention(args) - self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.mlp = MLP(args) - - def __call__(self, x, mask=None, cache=None): - residual = x - x = self.ln_1(x) - x = self.attn(x, mask=mask, cache=cache) - residual = x + residual - x = self.ln_2(residual) - x = self.mlp(x) - x = x + residual - - return x - - -class QwenModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.wte = nn.Embedding(args.vocab_size, args.hidden_size) - self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] - self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, inputs, mask=None, cache=None): - x = self.wte(inputs) - - if mask is None: - mask = create_attention_mask(x, cache) - - 
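        # With no caller-supplied cache, fall back to per-layer None placeholders so
        # the zip over (layer, cache) below still pairs one entry with each block.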
if cache is None: - cache = [None] * len(self.h) - - for layer, c in zip(self.h, cache): - x = layer(x, mask, c) - - return self.ln_f(x) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.transformer = QwenModel(config) - self.lm_head = nn.Linear( - config.hidden_size, config.vocab_size, bias=not config.no_bias - ) - self.args = config - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - y = self.transformer(x, mask, cache) - return self.lm_head(y) - - @property - def layers(self): - return self.transformer.h diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py deleted file mode 100644 index 381767c4..00000000 --- a/llms/mlx_lm/models/qwen2.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: Optional[int] = None - rope_theta: float = 1000000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = True - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = 
self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class Qwen2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = Qwen2Model(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - def sanitize(self, weights): - if self.args.tie_word_embeddings: - weights.pop("lm_head.weight", None) - # Remove unused precomputed rotary freqs - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py deleted file mode 100644 index c6aba622..00000000 --- a/llms/mlx_lm/models/qwen2_moe.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
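# Illustrative sketch, separate from the files in this diff: how the qwen2
# Attention above maps a linear `rope_scaling` config entry onto the scale
# passed to nn.RoPE. The config values shown are hypothetical.
rope_scaling = {"type": "linear", "factor": 4.0}  # hypothetical config values
rope_scale = (
    1 / rope_scaling["factor"]
    if rope_scaling is not None and rope_scaling["type"] == "linear"
    else 1.0
)
# rope_scale == 0.25: position indices are effectively compressed by the factor.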
- -import math -from dataclasses import dataclass -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - num_experts_per_tok: int - num_experts: int - moe_intermediate_size: int - shared_expert_intermediate_size: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: Optional[int] = None - rope_theta: float = 1000000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - tie_word_embeddings: bool = False - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class Qwen2MoeSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - dim = args.hidden_size - intermediate_size = args.moe_intermediate_size - shared_expert_intermediate_size = args.shared_expert_intermediate_size - - self.num_experts = num_experts = 
args.num_experts - self.top_k = args.num_experts_per_tok - - self.gate = nn.Linear(dim, num_experts, bias=False) - self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) - - self.shared_expert = MLP(dim, shared_expert_intermediate_size) - self.shared_expert_gate = nn.Linear(dim, 1, bias=False) - - def __call__( - self, - x: mx.array, - ): - gates = self.gate(x) - gates = mx.softmax(gates, axis=-1, precise=True) - - k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) - scores = mx.take_along_axis(gates, inds, axis=-1) - - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - - shared_expert_output = self.shared_expert(x) - shared_expert_output = ( - mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output - ) - - return y + shared_expert_output - - -class Qwen2MoeDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = Qwen2MoeSparseMoeBlock(args) - - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class Qwen2MoeModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = Qwen2MoeModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - return self.lm_head(out) - - def sanitize(self, weights): - if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: - return weights - for l in range(self.args.num_hidden_layers): - prefix = f"model.layers.{l}" - for n in ["up_proj", "down_proj", "gate_proj"]: - for k in ["weight", "scales", "biases"]: - if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: - to_join = [ - weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") - for e in range(self.args.num_experts) - ] - weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) - return weights - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py deleted file mode 100644 index ad07d925..00000000 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright © 2023-2024 
Apple Inc. - -import math -from dataclasses import dataclass -from typing import List, Literal, Optional - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .cache import MambaCache, RotatingKVCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - attention_bias: bool - conv1d_width: int - hidden_size: int - intermediate_size: int - logits_soft_cap: float - num_attention_heads: int - num_hidden_layers: int - num_key_value_heads: int - rms_norm_eps: float - rope_theta: float - attention_window_size: int - vocab_size: int - embeddings_scale_by_sqrt_dim: bool = True - block_types: Optional[List[str]] = None - _block_types: Optional[List[str]] = None - - def __post_init__(self): - # For some reason these have different names in 2B and 9B - if self.block_types is None: - self.block_types = self._block_types - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def __call__(self, x): - return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) - - -def rnn_scan(x, a, h0): - assert x.ndim == 3 - assert a.shape == x.shape[-a.ndim :] - assert a.dtype == x.dtype - - if x.shape[1] == 1: - # Using scan in sampling mode. - if h0 is None: - return x, x[:, 0] - - else: - y = a * h0[:, None] + x - return y, y[:, -1] - - else: - # Using scan in linear mode. - if h0 is not None: - h_t = h0 - else: - B, _, D = x.shape - h_t = mx.zeros((B, D), dtype=x.dtype) - - y = mx.zeros_like(x) - for t in range(x.shape[1]): - h_t = a[:, t] * h_t + x[:, t] - y[:, t] = h_t - - return y, h_t - - -class Conv1d(nn.Module): - def __init__( - self, - channels: int, - kernel_size: int, - ): - super().__init__() - self.weight = mx.zeros((channels, kernel_size, 1)) - self.bias = mx.zeros((channels,)) - - def __call__(self, x, cache=None): - B, L, C = x.shape - groups, K, _ = self.weight.shape - - if cache is not None: - x = mx.concatenate([cache, x], axis=1) - else: - x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - - y = mx.conv_general(x, self.weight, groups=groups) - y = y + self.bias - - return y, x[:, -K + 1 :, :] - - -class RGLRU(nn.Module): - """A Real-Gated Linear Recurrent Unit (RG-LRU) layer.""" - - def __init__( - self, - width: int, - num_heads: int, - ): - super().__init__() - self.width = width - self.num_heads = num_heads - self.head_dim = self.width // self.num_heads - - self.recurrent_param = mx.zeros((self.width,)) - - self.input_gate_weight = mx.zeros( - (self.num_heads, self.head_dim, self.head_dim), - ) - self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim)) - - self.recurrent_gate_weight = mx.zeros( - (self.num_heads, self.head_dim, self.head_dim), - ) - self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim)) - - def __call__( - self, - x: mx.array, - cache=None, - ): - B, L, _ = x.shape - - def apply_block_linear(h, w, b): - h = h.reshape((B, L, self.num_heads, self.head_dim)) - h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b - return mx.sigmoid(h.flatten(2, 3)) - - # Gates for x and a. - gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias) - gate_a = apply_block_linear( - x, self.recurrent_gate_weight, self.recurrent_gate_bias - ) - - # Compute the parameter `A` of the recurrence. - log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param) - a = mx.exp(log_a) - a_square = mx.exp(2 * log_a) - - # Gate the input. 
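        # With a = exp(-8 * gate_a * softplus(recurrent_param)) computed above, the
        # scan below realizes y_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (gate_x_t * x_t);
        # the sqrt(1 - a^2) factor keeps the recurrent state on a scale comparable
        # to its gated input.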
- gated_x = x * gate_x - - # Apply gamma normalization to the input. - multiplier = mx.sqrt(1 - a_square) - if cache is None: - multiplier[:, 0, :] = 1.0 - normalized_x = gated_x * multiplier.astype(x.dtype) - - y, last_h = rnn_scan( - x=normalized_x, - a=a, - h0=cache, - ) - - return y, last_h - - -class RecurrentBlock(nn.Module): - - def __init__( - self, - width: int, - num_heads: int, - lru_width: int = None, - conv1d_temporal_width: int = 4, - ): - super().__init__() - self.width = width - self.num_heads = num_heads - self.lru_width = lru_width or width - self.conv1d_temporal_width = conv1d_temporal_width - - self.linear_y = nn.Linear(width, self.lru_width) - self.linear_x = nn.Linear(width, self.lru_width) - self.linear_out = nn.Linear(self.lru_width, width) - self.conv_1d = Conv1d( - channels=self.lru_width, - kernel_size=self.conv1d_temporal_width, - ) - self.rg_lru = RGLRU( - width=self.lru_width, - num_heads=self.num_heads, - ) - - def __call__( - self, - x: mx.array, - cache=None, - mask=None, - ): - # y branch. - y = self.linear_y(x) - y = nn.gelu_approx(y) - - # x branch. - x = self.linear_x(x) - if cache is None: - cache = [None, None] - x, cache[0] = self.conv_1d(x=x, cache=cache[0]) - x, cache[1] = self.rg_lru(x=x, cache=cache[1]) - - x = x * y - x = self.linear_out(x) - - return x - - -class LocalAttentionBlock(nn.Module): - - def __init__( - self, - width: int, - num_heads: int, - window_size: int, - ): - super().__init__() - self.width = width - self.num_heads = num_heads - self.window_size = window_size - self.scale = (width // num_heads) ** (-0.5) - - self.head_dim = self.width // self.num_heads - self.q_proj = nn.Linear(self.width, self.width, bias=False) - self.k_proj = nn.Linear(self.width, self.head_dim, bias=False) - self.v_proj = nn.Linear(self.width, self.head_dim, bias=False) - self.o_proj = nn.Linear(self.width, self.width, bias=True) - self.rope = nn.RoPE( - self.head_dim // 2, - traditional=False, - ) - - def __call__( - self, - x: mx.array, - cache=None, - mask=None, - ): - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLPBlock(nn.Module): - - def __init__(self, width: int, expanded_width: int): - super().__init__() - self.up_proj = nn.Linear(width, expanded_width // 2) - self.gate_proj = nn.Linear(width, expanded_width // 2) - self.down_proj = nn.Linear(expanded_width // 2, width) - - def __call__(self, x: mx.array): - gate = self.gate_proj(x) - x = self.up_proj(x) - return self.down_proj(nn.gelu_approx(gate) * x) - - -class ResidualBlock(nn.Module): - - def __init__( - self, - width: int, - mlp_expanded_width: int, - num_heads: int, - attention_window_size: int, - temporal_block_type: str, - lru_width: Optional[int] = None, - conv1d_temporal_width: int = 4, - ): - """Initializes the residual block. - - Args: - width: The width of the block. 
- mlp_expanded_width: The width of the expansion inside the MLP block. - num_heads: The number of heads for the Attention or the RG-LRU. - attention_window_size: The window size for the local attention block. - temporal_block_type: Either "recurrent" or "attention", specifying the - type of recurrent block to use. - lru_width: The width of the RG-LRU if different from `width`. - conv1d_temporal_width: The width of the temporal convolution. - """ - super().__init__() - self.width = width - self.mlp_expanded_width = mlp_expanded_width - self.num_heads = num_heads - self.attention_window_size = attention_window_size - self.temporal_block_type = temporal_block_type - self.lru_width = lru_width - self.conv1d_temporal_width = conv1d_temporal_width - - self.temporal_pre_norm = RMSNorm(width) - if self.temporal_block_type == "recurrent": - self.temporal_block = RecurrentBlock( - width=self.width, - num_heads=self.num_heads, - lru_width=self.lru_width, - conv1d_temporal_width=self.conv1d_temporal_width, - ) - - else: - self.temporal_block = LocalAttentionBlock( - width=self.width, - num_heads=self.num_heads, - window_size=self.attention_window_size, - ) - - self.channel_pre_norm = RMSNorm(width) - self.mlp_block = MLPBlock( - width=self.width, - expanded_width=self.mlp_expanded_width, - ) - - def __call__( - self, - x: mx.array, - cache=None, - mask=None, - ): - raw_x = x - - inputs_normalized = self.temporal_pre_norm(raw_x) - - x = self.temporal_block(inputs_normalized, cache=cache, mask=mask) - residual = x + raw_x - - x = self.channel_pre_norm(residual) - x = self.mlp_block(x) - - x = x + residual - - return x - - -class Griffin(nn.Module): - def __init__(self, config): - super().__init__() - - self.config = config - self.embed_tokens = nn.Embedding( - config.vocab_size, - config.hidden_size, - ) - - self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim - block_types = config.block_types - - self.layers = [ - ResidualBlock( - width=config.hidden_size, - mlp_expanded_width=config.intermediate_size, - num_heads=config.num_attention_heads, - attention_window_size=config.attention_window_size, - temporal_block_type=block_types[i % len(block_types)], - lru_width=None, - ) - for i in range(config.num_hidden_layers) - ] - self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - tokens, - mask: mx.array = None, - cache=None, - ): - x = self.embed_tokens(tokens) - if self.scale_by_sqrt_dim: - x = x * math.sqrt(x.shape[-1]) - - if cache is None: - cache = [None] * len(self.layers) - - for i, block in enumerate(self.layers): - if block.temporal_block_type != "recurrent": - mask_cache = [cache[i]] - - if mask is None: - mask = create_attention_mask(x, mask_cache) - - for i, block in enumerate(self.layers): - x = block(x, mask=mask, cache=cache[i]) - - return self.final_norm(x) - - -class Model(nn.Module): - - def __init__(self, config): - self.args = config - self.model = Griffin(config) - self.model_type = config.model_type - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array: - """ - Args: - tokens: Sequence of input tokens. 
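            mask: Optional attention mask for the local attention blocks.
            cache: Optional list of per-layer caches, e.g. as returned by
                ``make_cache``.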
- """ - logits = self.model(tokens, mask=mask, cache=cache) - if "lm_head" in self: - logits = self.lm_head(logits) - else: - logits = self.model.embed_tokens.as_linear(logits) - - c = self.args.logits_soft_cap - if c: - logits = mx.tanh(logits / c) * c - return logits - - @property - def layers(self): - return self.model.layers - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv_1d.weight" in k and v.shape[-1] != 1: - weights[k] = v.moveaxis(2, 1) - if "lm_head.weight" not in weights: - self.pop("lm_head") - return weights - - def make_cache(self): - cache = [] - for layer in self.layers: - if layer.temporal_block_type == "recurrent": - cache.append(MambaCache()) - else: - cache.append(RotatingKVCache(max_size=self.args.attention_window_size)) - return cache diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py deleted file mode 100644 index d30b432d..00000000 --- a/llms/mlx_lm/models/rope_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -from typing import Optional - -import mlx.core as mx -import mlx.nn as nn - - -class Llama3RoPE(nn.Module): - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scaling_config: dict = None, - ): - super().__init__() - self.dims = dims - self.max_position_embeddings = max_position_embeddings - self.traditional = traditional - - factor = scaling_config["factor"] - low_freq_factor = scaling_config.get("low_freq_factor", 1.0) - high_freq_factor = scaling_config.get("high_freq_factor", 4.0) - old_context_len = scaling_config.get( - "original_max_position_embeddings", - 8192, - ) - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - freqs = base ** (mx.arange(0, dims, 2) / dims) - wavelens = 2 * mx.pi * freqs - - freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) - is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) - smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) - self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) - - def extra_repr(self): - return ( - f"{self.dims}, traditional={self.traditional}, " - f"max_position_embeddings={self.max_position_embeddings}" - ) - - def __call__(self, x, offset: int = 0): - return mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=None, - scale=1.0, - offset=offset, - freqs=self._freqs, - ) - - -def initialize_rope( - dims, - base, - traditional, - scaling_config: Optional[dict] = None, - max_position_embeddings: Optional[int] = None, -): - if scaling_config is not None: - rope_type = scaling_config.get("type") or scaling_config.get( - "rope_type", "default" - ) - else: - rope_type = "default" - - if rope_type in ["default", "linear"]: - scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0 - return nn.RoPE(dims, traditional=traditional, base=base, scale=scale) - - elif rope_type == "llama3": - return Llama3RoPE( - dims=dims, - max_position_embeddings=max_position_embeddings, - traditional=traditional, - base=base, - scaling_config=scaling_config, - ) - else: - raise ValueError(f"Unsupported RoPE type {rope_type}") diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py deleted file mode 100644 index 0bbc2ca4..00000000 --- a/llms/mlx_lm/models/stablelm.py +++ 
/dev/null @@ -1,211 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math -from dataclasses import dataclass - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - num_attention_heads: int - num_hidden_layers: int - num_key_value_heads: int - intermediate_size: int - rope_theta: float - use_qkv_bias: bool - partial_rotary_factor: float - layer_norm_eps: float - use_parallel_residual: bool = False - qk_layernorm: bool = False - - -class LayerNormPerHead(nn.Module): - - def __init__(self, head_dim, num_heads, eps): - super().__init__() - self.norms = [ - nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads) - ] - self.eps = eps - - def __call__(self, x): - w = mx.stack([n.weight for n in self.norms]) - return w * mx.fast.layer_norm(x, None, None, self.eps) - - -class Attention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.rope_theta = config.rope_theta - self.partial_rotary_factor = config.partial_rotary_factor - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.use_qkv_bias, - ) - self.v_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.use_qkv_bias, - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.rope = nn.RoPE( - int(self.partial_rotary_factor * self.head_dim), - traditional=False, - base=self.rope_theta, - ) - - self.qk_layernorm = config.qk_layernorm - if self.qk_layernorm: - self.q_layernorm = LayerNormPerHead( - self.head_dim, self.num_heads, eps=config.layer_norm_eps - ) - self.k_layernorm = LayerNormPerHead( - self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps - ) - - def __call__(self, x, mask=None, cache=None): - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Extract some shapes - B, L, D = queries.shape - - queries = queries.reshape(B, L, self.num_heads, -1) - keys = keys.reshape(B, L, self.num_key_value_heads, -1) - if self.qk_layernorm: - queries = self.q_layernorm(queries) - keys = self.k_layernorm(keys) - queries = queries.transpose(0, 2, 1, 3) - keys = keys.transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 - ) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries.astype(mx.float32) - keys = keys.astype(mx.float32) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=scale, 
mask=mask - ).astype(values.dtype) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class DecoderLayer(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.self_attn = Attention(config=config) - self.mlp = MLP(config.hidden_size, config.intermediate_size) - self.input_layernorm = nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - ) - self.use_parallel_residual = config.use_parallel_residual - if not self.use_parallel_residual: - self.post_attention_layernorm = nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - ) - - def __call__(self, x, mask, cache): - h = self.input_layernorm(x) - r = self.self_attn(h, mask, cache) - - if self.use_parallel_residual: - out = x + r + self.mlp(h) - else: - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class StableLM(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)] - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def __call__(self, x, mask, cache): - x = self.embed_tokens(x) - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - x = layer(x, mask, cache=c) - - return self.norm(x) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.model = StableLM(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.args = config - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache=None, - ) -> mx.array: - - if mask is None: - mask = create_attention_mask(x, cache) - - y = self.model(x, mask, cache) - return self.lm_head(y) - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py deleted file mode 100644 index 71c397f6..00000000 --- a/llms/mlx_lm/models/starcoder2.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. 
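# Illustrative summary, separate from the files in this diff: the two residual
# layouts implemented by the stablelm DecoderLayer above.
#
#   parallel residual:    out = x + attn(norm1(x)) + mlp(norm1(x))
#   sequential residual:  h = x + attn(norm1(x)); out = h + mlp(norm2(h))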
- -from dataclasses import dataclass -from typing import Any, Optional - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - num_key_value_heads: int - norm_epsilon: float = 1e-5 - vocab_size: int = 49152 - rope_theta: float = 100000 - tie_word_embeddings: bool = True - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // args.num_attention_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) - self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.c_fc = nn.Linear(dim, hidden_dim, bias=True) - self.c_proj = nn.Linear(hidden_dim, dim, bias=True) - - def __call__(self, x): - return self.c_proj(nn.gelu(self.c_fc(x))) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.n_heads = args.num_attention_heads - - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm( - args.hidden_size, eps=args.norm_epsilon - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out - - -class Starcoder2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = 
nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - h = self.embed_tokens(inputs) - - if mask is None: - mask = create_attention_mask(h, cache) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) - - return self.norm(h) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = Starcoder2Model(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - mask: mx.array = None, - cache=None, - ): - out = self.model(inputs, mask, cache) - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.lm_head(out) - return out - - @property - def layers(self): - return self.model.layers diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py deleted file mode 100644 index 6340c77b..00000000 --- a/llms/mlx_lm/models/su_rope.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math -from typing import List, Union - -import mlx.core as mx -import mlx.nn as nn - - -class SuScaledRotaryEmbedding(nn.Module): - def __init__( - self, - dims: int, - base: float = 10000.0, - max_position_embeddings: int = 131072, - original_max_position_embeddings: int = 4096, - short_factor: Union[List[float], float] = 1.0, - long_factor: Union[List[float], float] = 1.0, - short_mscale: float = None, - long_mscale: float = None, - ): - """ - Phi3Su Scaled Rotary Embedding layer for Phi-3 models. - - Args: - dims (int): The feature dimensions to be rotated. - base (int, optional): Base for the exponential scaling. - max_position_embeddings (int, optional): The maximum sequence - length that this model was trained with. This is used to determine - the size of the original RoPE embeddings when using long scaling. - Default: ``131072``. - original_max_position_embeddings (int, optional): The maximum - sequence length that this model was trained with. This is used to - determine the size of the original RoPE embeddings when using long - scaling. Default: ``4096``. - short_factor (float or list[float], optional): List of scaling - factors for sequences of length lesser than - ``original_max_position_embeddings``. Default: ``1.0``. - long_factor (float or list[float], optional): List of scaling - factors for sequences of length greater than - ``original_max_position_embeddings``. Default: ``1.0``. - short_mscale (float, optional): Scale the input prior to embedding. - long_mscale (float, optional): Scale the input prior to embedding. 
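        Example (illustrative, hypothetical values)::

            rope = SuScaledRotaryEmbedding(
                dims=96,
                max_position_embeddings=131072,
                original_max_position_embeddings=4096,
                long_factor=[1.0] * 48,
            )
            queries = rope(queries, offset=0)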
- """ - super().__init__() - freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) - self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs - self.original_max_position_embeddings = original_max_position_embeddings - self.scale = long_mscale or math.sqrt( - 1 - + math.log(max_position_embeddings / original_max_position_embeddings) - / math.log(original_max_position_embeddings) - ) - self.dim = dims - - def __call__(self, x, offset: int = 0): - x[..., : self.dim] = self.scale * x[..., : self.dim] - return mx.fast.rope( - x, - self.dim, - traditional=False, - base=None, - scale=1.0, - offset=offset, - freqs=self._freqs, - ) diff --git a/llms/mlx_lm/models/switch_layers.py b/llms/mlx_lm/models/switch_layers.py deleted file mode 100644 index 4a157473..00000000 --- a/llms/mlx_lm/models/switch_layers.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math - -import mlx.core as mx -import mlx.nn as nn - - -class QuantizedSwitchLinear(nn.Module): - def __init__( - self, - input_dims: int, - output_dims: int, - num_experts: int, - bias: bool = True, - group_size: int = 64, - bits: int = 4, - ): - super().__init__() - - scale = math.sqrt(1 / input_dims) - self.weight, self.scales, self.biases = mx.quantize( - mx.random.uniform( - low=-scale, - high=scale, - shape=(num_experts, output_dims, input_dims), - ), - group_size=group_size, - bits=bits, - ) - - if bias: - self.bias = mx.zeros((num_experts, output_dims)) - - self.group_size = group_size - self.bits = bits - - # Freeze this model's parameters - self.freeze() - - def unfreeze(self, *args, **kwargs): - """Wrap unfreeze so that we unfreeze any layers we might contain but - our parameters will remain frozen.""" - super().unfreeze(*args, **kwargs) - self.freeze(recurse=False) - - @property - def input_dims(self): - return self.scales.shape[2] * self.group_size - - @property - def output_dims(self): - return self.weight.shape[1] - - @property - def num_experts(self): - return self.weight.shape[0] - - def __call__(self, x, indices): - x = mx.gather_qmm( - x, - self["weight"], - self["scales"], - self["biases"], - rhs_indices=indices, - transpose=True, - group_size=self.group_size, - bits=self.bits, - ) - if "bias" in self: - x = x + mx.expand_dims(self["bias"][indices], -2) - return x - - -class SwitchLinear(nn.Module): - def __init__( - self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True - ): - super().__init__() - scale = math.sqrt(1 / input_dims) - self.weight = mx.random.uniform( - low=-scale, - high=scale, - shape=(num_experts, output_dims, input_dims), - ) - - if bias: - self.bias = mx.zeros((num_experts, output_dims)) - - @property - def input_dims(self): - return self.weight.shape[2] - - @property - def output_dims(self): - return self.weight.shape[1] - - @property - def num_experts(self): - return self.weight.shape[0] - - def __call__(self, x, indices): - x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices) - if "bias" in self: - x = x + mx.expand_dims(self["bias"][indices], -2) - return x - - def to_quantized(self, group_size: int = 64, bits: int = 4): - num_experts, output_dims, input_dims = self.weight.shape - ql = QuantizedSwitchLinear( - input_dims, output_dims, num_experts, False, group_size, bits - ) - ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits) - if "bias" in self: - ql.bias = self.bias - return ql - - -class SwitchGLU(nn.Module): - def __init__( - self, - input_dims: int, - hidden_dims: int, - num_experts: 
int, - activation=nn.silu, - bias: bool = False, - ): - super().__init__() - - self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) - self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) - self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) - self.activation = activation - - def __call__(self, x, indices) -> mx.array: - x = mx.expand_dims(x, (-2, -3)) - - x_up = self.up_proj(x, indices) - x_gate = self.gate_proj(x, indices) - x = self.down_proj(self.activation(x_gate) * x_up, indices) - - return x.squeeze(-2) - - -class SwitchMLP(nn.Module): - def __init__( - self, - input_dims: int, - hidden_dims: int, - num_experts: int, - activation=nn.gelu_approx, - bias: bool = False, - ): - super().__init__() - - self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) - self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) - self.activation = activation - - def __call__(self, x, indices) -> mx.array: - x = mx.expand_dims(x, (-2, -3)) - - x = self.fc1(x, indices) - x = self.activation(x) - x = self.fc2(x, indices) - - return x.squeeze(-2) diff --git a/llms/mlx_lm/py.typed b/llms/mlx_lm/py.typed deleted file mode 100644 index 8b137891..00000000 --- a/llms/mlx_lm/py.typed +++ /dev/null @@ -1 +0,0 @@ - diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt deleted file mode 100644 index 72e1ef89..00000000 --- a/llms/mlx_lm/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -mlx>=0.22.0 -numpy -transformers[sentencepiece]>=4.39.3 -protobuf -pyyaml -jinja2 diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py deleted file mode 100644 index efc5b556..00000000 --- a/llms/mlx_lm/sample_utils.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import math -from functools import partial -from typing import Callable, Dict, Optional - -import mlx.core as mx - - -def make_sampler( - temp: float = 0.0, - top_p: float = 0.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, - top_k: int = -1, -) -> Callable[mx.array, mx.array]: - """ - Make a sampler function for use with ``generate_step``. - - Args: - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. - top_k (int, optional): The top k tokens ranked by probability to constrain - the sampling to. - - Returns: - Callable[mx.array, mx.array]: - A sampler which takes log-probabilities and returns tokens. 
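        Example (illustrative)::

            sampler = make_sampler(temp=0.7, top_p=0.9)
            token = sampler(logprobs)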
- """ - if temp == 0: - return lambda x: mx.argmax(x, axis=-1) - - # Create sampler chain - sampling_methods = [] - if top_k > 0: - sampling_methods.append(lambda x: apply_top_k(x, top_k)) - if top_p > 0 and top_p < 1.0: - sampling_methods.append(lambda x: apply_top_p(x, top_p)) - if min_p != 0.0: - sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep)) - - # Apply the sampling methods - def sampler(logits): - for method in sampling_methods: - logits = method(logits) - - # Return the sampled token - return categorical_sampling(logits, temp) - - return sampler - - -def make_logits_processors( - logit_bias: Optional[Dict[int, float]] = None, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, -): - """ - Make logits processors for use with ``generate_step``. - - Args: - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - logit_bias (dictionary, optional): Additive logit bias. - - Returns: - List[Callable[[mx.array, mx.array], mx.array]]: - A list of logits processors. Each processor in the list is a - callable which takes an array of tokens and an array of logits - and returns the updated logits. - """ - logits_processors = [] - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - - def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits - - logits_processors.append(logit_bias_processor) - - if repetition_penalty and repetition_penalty != 0.0: - logits_processors.append( - make_repetition_penalty(repetition_penalty, repetition_context_size) - ) - return logits_processors - - -@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def apply_top_k( - logprobs: mx.array, - top_k: int, -) -> mx.array: - """ - Sample from only the top K tokens ranked by probability. - - Args: - logprobs: A vector of log probabilities. - top_k (int): Top k tokens to sample from. - """ - vocab_size = logprobs.shape[-1] - if not isinstance(top_k, int) or not (0 < top_k < vocab_size): - raise ValueError( - f"`top_k` has to be an integer in the (0, {vocab_size}] interval," - f" but is {top_k}." - ) - mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] - masked_logprobs = mx.put_along_axis( - logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 - ) - return masked_logprobs - - -@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def apply_min_p( - logprobs: mx.array, - min_p: float, - min_tokens_to_keep: int = 1, -) -> mx.array: - """ - Apply min-p sampling to the logprobs. - - Min-p keeps all tokens that are above a minimum probability, scaled by the - probability of the most likely token. As a result, the filter is more - aggressive given a very high-probability token. - - Args: - logprobs: A vector of log probabilities. - min_p (float): Minimum token probability. Typical values are in the - 0.01-0.2 range, comparably selective as setting `top_p` in the - 0.99-0.8 range. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered. Default: ``1``. 
- - """ - if not (0 <= min_p <= 1.0): - raise ValueError( - f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}" - ) - if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): - raise ValueError( - f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}" - ) - # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 - - # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logprobs, axis=-1) - sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1) - - # Top probability - top_logprobs = sorted_logprobs[:, 0:1] - - # Calculate the min_p threshold - scaled_min_p = top_logprobs + math.log(min_p) - - # Mask tokens that have a probability less than the scaled min_p - tokens_to_remove = sorted_logprobs < scaled_min_p - tokens_to_remove[..., :min_tokens_to_keep] = False - - # Create pool of tokens with probability less than scaled min_p - selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) - - # Create a mapping to rearrange back to original indices - # Use argsort of sorted_indices to get the inverse permutation - inverse_indices = mx.argsort(sorted_indices, axis=-1) - - # Rearrange selected_logprobs back to original order - original_order_logprobs = mx.take_along_axis( - selected_logprobs, inverse_indices, axis=-1 - ) - - return original_order_logprobs - - -@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def apply_top_p(logits: mx.array, top_p: float) -> mx.array: - """ - Apply top-p (nucleus) sampling to logits. - - Args: - logits: The logits from the model's output. - top_p: The cumulative probability threshold for top-p filtering. - Returns: - token selected based on the top-p criterion. - """ - # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits, axis=-1) - - # sort probs in ascending order - sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) - - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) - - # select tokens with cumulative probs below threshold - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - 0, - ) - - # Create a mapping to rearrange back to original indices - # Use argsort of sorted_indices to get the inverse permutation - inverse_indices = mx.argsort(sorted_indices, axis=-1) - - # Rearrange top_probs back to original order - original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1) - - # Convert back to logits and return - return mx.log(original_order_probs) - - -@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) -def categorical_sampling(logits, temp): - return mx.random.categorical(logits * (1 / temp)) - - -def make_repetition_penalty(penalty: float, context_size: int = 20): - """ - Make repetition penalty processor. - - Paper: https://arxiv.org/abs/1909.05858 - - Args: - penalty (float): The repetition penalty factor to be applied. - context_size (int): The number of previous tokens to use. - Default: ``20``. - - Returns: - Callable[[mx.array, List[int]], mx.array]: - The repetition penalty processor. 
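-
- Example (illustrative; assumes ``model``, ``tokenizer``, and ``prompt`` as in
- the package README, with an arbitrary penalty value; the processor is
- normally built via ``make_logits_processors``):
-
- from mlx_lm import stream_generate
-
- processors = make_logits_processors(repetition_penalty=1.1)
- for response in stream_generate(
- model, tokenizer, prompt, logits_processors=processors
- ):
- print(response.text, end="", flush=True)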
- """ - if penalty < 0 or not isinstance(penalty, (int, float)): - raise ValueError(f"penalty must be a non-negative float, got {penalty}") - - def repetition_penalty_processor(tokens, logits): - if len(tokens) > 0: - tokens = tokens[-context_size:] - selected_logits = logits[:, tokens] - selected_logits = mx.where( - selected_logits < 0, - selected_logits * penalty, - selected_logits / penalty, - ) - logits[:, tokens] = selected_logits - return logits - - return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py deleted file mode 100644 index de02704d..00000000 --- a/llms/mlx_lm/server.py +++ /dev/null @@ -1,785 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import argparse -import json -import logging -import platform -import time -import uuid -import warnings -from dataclasses import dataclass, field -from http.server import BaseHTTPRequestHandler, HTTPServer -from pathlib import Path -from typing import ( - Any, - Dict, - List, - Literal, - NamedTuple, - Optional, - Sequence, - Tuple, - Union, -) - -import mlx.core as mx -from huggingface_hub import scan_cache_dir - -from ._version import __version__ -from .models.cache import make_prompt_cache -from .sample_utils import make_logits_processors, make_sampler -from .utils import load, stream_generate - - -def get_system_fingerprint(): - gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else "" - return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" - - -class StopCondition(NamedTuple): - stop_met: bool - trim_length: int - - -def stopping_criteria( - tokens: List[int], - stop_id_sequences: List[List[int]], - eos_token_id: Union[int, None], -) -> StopCondition: - """ - Determines whether the token generation should stop based on predefined - conditions. - - Args: - tokens (List[int]): The current sequence of generated tokens. - stop_id_sequences (List[List[[int]]): A list of integer lists, each - representing a sequence of token IDs. If the end of the `tokens` - list matches any of these sequences, the generation should stop. - eos_token_id (Union[int, None]): The token ID that represents the - end-of-sequence. If the last token in `tokens` matches this, the - generation should stop. - - Returns: - StopCondition: A named tuple indicating whether the stop condition has - been met (`stop_met`) and how many tokens should be trimmed from the - end if it has (`trim_length`). - """ - if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=0) - - for stop_ids in stop_id_sequences: - if len(tokens) >= len(stop_ids): - if tokens[-len(stop_ids) :] == stop_ids: - return StopCondition(stop_met=True, trim_length=len(stop_ids)) - - return StopCondition(stop_met=False, trim_length=0) - - -def sequence_overlap(s1: Sequence, s2: Sequence) -> bool: - """ - Checks if a suffix of s1 has overlap with a prefix of s2 - - Args: - s1 (Sequence): The first sequence - s2 (Sequence): The second sequence - - Returns: - bool: If the two sequences have overlap - """ - max_overlap = min(len(s1), len(s2)) - return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1)) - - -def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): - default_role_mapping = { - "system_prompt": ( - "A chat between a curious user and an artificial intelligence " - "assistant. The assistant follows the given rules no matter what." 
- ), - "system": "ASSISTANT's RULE: ", - "user": "USER: ", - "assistant": "ASSISTANT: ", - "stop": "\n", - } - role_mapping = role_mapping if role_mapping is not None else default_role_mapping - - prompt = "" - for line in messages: - role_prefix = role_mapping.get(line["role"], "") - stop = role_mapping.get("stop", "") - content = line.get("content", "") - prompt += f"{role_prefix}{content}{stop}" - - prompt += role_mapping.get("assistant", "") - return prompt.rstrip() - - -def process_message_content(messages): - """ - Convert message content to a format suitable for `apply_chat_template`. - - The function operates on messages in place. It converts the 'content' field - to a string instead of a list of text fragments. - - Args: - message_list (list): A list of dictionaries, where each dictionary may - have a 'content' key containing a list of dictionaries with 'type' and - 'text' keys. - - Raises: - ValueError: If the 'content' type is not supported or if 'text' is missing. - - """ - for message in messages: - content = message["content"] - if isinstance(content, list): - text_fragments = [ - fragment["text"] for fragment in content if fragment["type"] == "text" - ] - if len(text_fragments) != len(content): - raise ValueError("Only 'text' content type is supported.") - message["content"] = "".join(text_fragments) - - -@dataclass -class PromptCache: - cache: List[Any] = field(default_factory=list) - model_key: Tuple[str, Optional[str]] = ("", None) - tokens: List[int] = field(default_factory=list) - - -class ModelProvider: - def __init__(self, cli_args: argparse.Namespace): - """Load models on demand and persist them across the whole process.""" - self.cli_args = cli_args - self.model_key = None - self.model = None - self.tokenizer = None - - # Preload the default model if it is provided - if self.cli_args.model is not None: - self.load("default_model") - - def _validate_model_path(self, model_path: str): - model_path = Path(model_path) - if model_path.exists() and not model_path.is_relative_to(Path.cwd()): - raise RuntimeError( - "Local models must be relative to the current working dir." - ) - - # Added in adapter_path to load dynamically - def load(self, model_path, adapter_path=None): - if self.model_key == (model_path, adapter_path): - return self.model, self.tokenizer - - # Remove the old model if it exists. 
- self.model = None - self.tokenizer = None - self.model_key = None - - # Building tokenizer_config - tokenizer_config = { - "trust_remote_code": True if self.cli_args.trust_remote_code else None - } - if self.cli_args.chat_template: - tokenizer_config["chat_template"] = self.cli_args.chat_template - - if model_path == "default_model" and self.cli_args.model is not None: - model, tokenizer = load( - self.cli_args.model, - adapter_path=( - adapter_path if adapter_path else self.cli_args.adapter_path - ), # if the user doesn't change the model but adds an adapter path - tokenizer_config=tokenizer_config, - ) - else: - self._validate_model_path(model_path) - model, tokenizer = load( - model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config - ) - - if self.cli_args.use_default_chat_template: - if tokenizer.chat_template is None: - tokenizer.chat_template = tokenizer.default_chat_template - - self.model_key = (model_path, adapter_path) - self.model = model - self.tokenizer = tokenizer - - return self.model, self.tokenizer - - -class APIHandler(BaseHTTPRequestHandler): - def __init__( - self, - model_provider: ModelProvider, - *args, - prompt_cache: Optional[PromptCache] = None, - system_fingerprint: Optional[str] = None, - **kwargs, - ): - """ - Create static request specific metadata - """ - self.created = int(time.time()) - self.model_provider = model_provider - self.prompt_cache = prompt_cache or PromptCache() - self.system_fingerprint = system_fingerprint or get_system_fingerprint() - super().__init__(*args, **kwargs) - - def _set_cors_headers(self): - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Methods", "*") - self.send_header("Access-Control-Allow-Headers", "*") - - def _set_completion_headers(self, status_code: int = 200): - self.send_response(status_code) - self.send_header("Content-type", "application/json") - self._set_cors_headers() - - def _set_stream_headers(self, status_code: int = 200): - self.send_response(status_code) - self.send_header("Content-type", "text/event-stream") - self.send_header("Cache-Control", "no-cache") - self._set_cors_headers() - - def do_OPTIONS(self): - self._set_completion_headers(204) - self.end_headers() - - def do_POST(self): - """ - Respond to a POST request from a client. 
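-
- An illustrative request body for ``/v1/chat/completions`` (the field values
- are arbitrary and only a subset of the supported fields is shown):
-
- {
- "model": "default_model",
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 100,
- "temperature": 0.7,
- "stream": false
- }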
- """ - endpoints = { - "/v1/completions": self.handle_text_completions, - "/v1/chat/completions": self.handle_chat_completions, - "/chat/completions": self.handle_chat_completions, - } - - if self.path not in endpoints: - self._set_completion_headers(404) - self.end_headers() - self.wfile.write(b"Not Found") - return - - # Fetch and parse request body - content_length = int(self.headers["Content-Length"]) - raw_body = self.rfile.read(content_length) - self.body = json.loads(raw_body.decode()) - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}") - assert isinstance( - self.body, dict - ), f"Request should be dict, but got {type(self.body)}" - - # Extract request parameters from the body - self.stream = self.body.get("stream", False) - self.stream_options = self.body.get("stream_options", None) - self.requested_model = self.body.get("model", "default_model") - self.adapter = self.body.get("adapters", None) - self.max_tokens = self.body.get("max_completion_tokens", None) - if self.max_tokens is None: - self.max_tokens = self.body.get("max_tokens", 512) - self.temperature = self.body.get("temperature", 0.0) - self.top_p = self.body.get("top_p", 1.0) - self.repetition_penalty = self.body.get("repetition_penalty", 1.0) - self.repetition_context_size = self.body.get("repetition_context_size", 20) - self.logit_bias = self.body.get("logit_bias", None) - self.logprobs = self.body.get("logprobs", -1) - self.validate_model_parameters() - - # Load the model if needed - try: - self.model, self.tokenizer = self.model_provider.load( - self.requested_model, self.adapter - ) - except: - self._set_completion_headers(404) - self.end_headers() - self.wfile.write(b"Not Found") - return - - # Get stop id sequences, if provided - stop_words = self.body.get("stop") - stop_words = stop_words or [] - stop_words = [stop_words] if isinstance(stop_words, str) else stop_words - stop_id_sequences = [ - self.tokenizer.encode(stop_word, add_special_tokens=False) - for stop_word in stop_words - ] - - # Send header type - ( - self._set_stream_headers(200) - if self.stream - else self._set_completion_headers(200) - ) - - # Call endpoint specific method - prompt = endpoints[self.path]() - self.handle_completion(prompt, stop_id_sequences) - - def validate_model_parameters(self): - """ - Validate the model parameters passed in the request for the correct types and values. 
- """ - if not isinstance(self.stream, bool): - raise ValueError("stream must be a boolean") - - if not isinstance(self.max_tokens, int) or self.max_tokens < 0: - raise ValueError("max_tokens must be a non-negative integer") - - if not isinstance(self.temperature, (float, int)) or self.temperature < 0: - raise ValueError("temperature must be a non-negative float") - - if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1: - raise ValueError("top_p must be a float between 0 and 1") - - if ( - not isinstance(self.repetition_penalty, (float, int)) - or self.repetition_penalty < 0 - ): - raise ValueError("repetition_penalty must be a non-negative float") - - if self.logprobs != -1 and not (0 < self.logprobs <= 10): - raise ValueError( - f"logprobs must be between 1 and 10 but got {self.logprobs:,}" - ) - - if ( - not isinstance(self.repetition_context_size, int) - or self.repetition_context_size < 0 - ): - raise ValueError("repetition_context_size must be a non-negative integer") - - if self.logit_bias is not None: - if not isinstance(self.logit_bias, dict): - raise ValueError("logit_bias must be a dict of int to float") - - try: - self.logit_bias = {int(k): v for k, v in self.logit_bias.items()} - except ValueError: - raise ValueError("logit_bias must be a dict of int to float") - - if not isinstance(self.requested_model, str): - raise ValueError("model must be a string") - if self.adapter is not None and not isinstance(self.adapter, str): - raise ValueError("adapter must be a string") - - def generate_response( - self, - text: str, - finish_reason: Union[Literal["length", "stop"], None], - prompt_token_count: Optional[int] = None, - completion_token_count: Optional[int] = None, - token_logprobs: Optional[List[float]] = None, - top_tokens: Optional[List[Dict[int, float]]] = None, - tokens: Optional[List[int]] = None, - ) -> dict: - """ - Generate a single response packet based on response type (stream or - not), completion type and parameters. - - Args: - text (str): Text generated by model - finish_reason (Union[Literal["length", "stop"], None]): The reason the - response is being sent: "length", "stop" or `None`. - prompt_token_count (Optional[int]): The number of tokens in the prompt, - used to populate the "usage" field (not used when stream). - completion_token_count (Optional[int]): The number of tokens in the - response, used to populate the "usage" field (not used when stream). - token_logprobs (Optional[List[float]]): The log probabilities per token, - in token order. - top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping - tokens to logprobs for the top N tokens at each token position. - tokens (Optional[List[int]]): List of tokens to return with logprobs structure - - Returns: - dict: A dictionary containing the response, in the same format as - OpenAI's API. 
- """ - token_logprobs = token_logprobs if token_logprobs else [] - top_logprobs = top_tokens if top_tokens else [] - - # Static response - response = { - "id": self.request_id, - "system_fingerprint": self.system_fingerprint, - "object": self.object_type, - "model": self.requested_model, - "created": self.created, - "choices": [ - { - "index": 0, - "logprobs": { - "token_logprobs": token_logprobs, - "top_logprobs": top_logprobs, - "tokens": tokens, - }, - "finish_reason": finish_reason, - } - ], - } - - if not self.stream: - if not ( - isinstance(prompt_token_count, int) - and isinstance(completion_token_count, int) - ): - raise ValueError( - "Response type is complete, but token counts not provided" - ) - - response["usage"] = { - "prompt_tokens": prompt_token_count, - "completion_tokens": completion_token_count, - "total_tokens": prompt_token_count + completion_token_count, - } - - choice = response["choices"][0] - - # Add dynamic response - if self.object_type.startswith("chat.completion"): - key_name = "delta" if self.stream else "message" - choice[key_name] = {"role": "assistant", "content": text} - elif self.object_type == "text_completion": - choice.update(text=text) - else: - ValueError(f"Unsupported response type: {self.object_type}") - - return response - - def get_prompt_cache(self, prompt): - cache_len = len(self.prompt_cache.tokens) - if ( - self.prompt_cache.model_key != self.model_provider.model_key - or cache_len >= len(prompt) - or self.prompt_cache.tokens != prompt[:cache_len] - ): - self.prompt_cache.model_key = self.model_provider.model_key - self.prompt_cache.cache = make_prompt_cache(self.model_provider.model) - else: - prompt = prompt[cache_len:] - self.prompt_cache.tokens.extend(prompt) - return prompt - - def handle_completion( - self, - prompt: List[int], - stop_id_sequences: List[List[int]], - ): - """ - Generate a response to a prompt and send it to the client in a single batch. - - Args: - prompt (List[int]): The tokenized prompt. 
- stop_id_sequences (List[List[int]]): A list of stop words passed - to the stopping_criteria function - """ - tokens = [] - finish_reason = "length" - stop_sequence_suffix = None - if self.stream: - self.end_headers() - logging.debug(f"Starting stream:") - else: - logging.debug(f"Starting completion:") - token_logprobs = [] - top_tokens = [] - - prompt = self.get_prompt_cache(prompt) - - text = "" - tic = time.perf_counter() - sampler = make_sampler(self.temperature, top_p=self.top_p) - logits_processors = make_logits_processors( - self.logit_bias, self.repetition_penalty, self.repetition_context_size - ) - for gen_response in stream_generate( - model=self.model, - tokenizer=self.tokenizer, - prompt=prompt, - max_tokens=self.max_tokens, - sampler=sampler, - logits_processors=logits_processors, - prompt_cache=self.prompt_cache.cache, - ): - segment = gen_response.text - text += segment - logging.debug(text) - token = gen_response.token - logprobs = gen_response.logprobs - tokens.append(token) - - if self.logprobs > 0: - sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1) - top_indices = sorted_indices[: self.logprobs] - top_logprobs = logprobs[top_indices] - top_token_info = zip(top_indices.tolist(), top_logprobs.tolist()) - top_tokens.append(tuple(top_token_info)) - - token_logprobs.append(logprobs[token].item()) - - stop_condition = stopping_criteria( - tokens, stop_id_sequences, self.tokenizer.eos_token_id - ) - if stop_condition.stop_met: - finish_reason = "stop" - if stop_condition.trim_length: - stop_sequence_suffix = self.tokenizer.decode( - tokens[-stop_condition.trim_length :] - ) - text = text[: -len(stop_sequence_suffix)] - break - - if self.stream: - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - ( - sequence_overlap(tokens, sequence) - for sequence in stop_id_sequences - ) - ): - continue - elif segment: - response = self.generate_response(segment, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - - self.prompt_cache.tokens.extend(tokens) - - logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec") - logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec") - logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB") - - if self.stream: - response = self.generate_response(segment, finish_reason) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - if self.stream_options is not None and self.stream_options["include_usage"]: - response = self.completion_usage_response(len(prompt), len(tokens)) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() - else: - response = self.generate_response( - text, - finish_reason, - len(prompt), - len(tokens), - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - ) - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - - # Send an additional Content-Length header when it is known - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - self.wfile.write(response_json) - self.wfile.flush() - - def completion_usage_response( - self, - prompt_token_count: Optional[int] = None, - completion_token_count: Optional[int] = None, - ): - response = { - 
"id": self.request_id, - "system_fingerprint": self.system_fingerprint, - "object": "chat.completion", - "model": self.requested_model, - "created": self.created, - "choices": [], - "usage": { - "prompt_tokens": prompt_token_count, - "completion_tokens": completion_token_count, - "total_tokens": prompt_token_count + completion_token_count, - }, - } - return response - - def handle_chat_completions(self) -> List[int]: - """ - Handle a chat completion request. - - Returns: - mx.array: A mx.array of the tokenized prompt from the request body - """ - body = self.body - assert "messages" in body, "Request did not contain messages" - - # Determine response type - self.request_id = f"chatcmpl-{uuid.uuid4()}" - self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" - if self.tokenizer.chat_template: - messages = body["messages"] - process_message_content(messages) - prompt = self.tokenizer.apply_chat_template( - messages, - body.get("tools", None), - add_generation_prompt=True, - ) - else: - prompt = convert_chat(body["messages"], body.get("role_mapping")) - prompt = self.tokenizer.encode(prompt) - - return prompt - - def handle_text_completions(self) -> List[int]: - """ - Handle a text completion request. - - Returns: - mx.array: A mx.array of the tokenized prompt from the request body - """ - # Determine response type - self.request_id = f"cmpl-{uuid.uuid4()}" - self.object_type = "text_completion" - assert "prompt" in self.body, "Request did not contain a prompt" - return self.tokenizer.encode(self.body["prompt"]) - - def do_GET(self): - """ - Respond to a GET request from a client. - """ - if self.path == "/v1/models": - self.handle_models_request() - else: - self._set_completion_headers(404) - self.end_headers() - self.wfile.write(b"Not Found") - - def handle_models_request(self): - """ - Handle a GET request for the /v1/models endpoint. - """ - self._set_completion_headers(200) - self.end_headers() - - # Scan the cache directory for downloaded mlx models - hf_cache_info = scan_cache_dir() - downloaded_models = [ - repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id - ] - - # Create a list of available models - models = [ - { - "id": repo.repo_id, - "object": "model", - "created": self.created, - } - for repo in downloaded_models - ] - - response = {"object": "list", "data": models} - - response_json = json.dumps(response).encode() - self.wfile.write(response_json) - self.wfile.flush() - - -def run( - host: str, - port: int, - model_provider: ModelProvider, - server_class=HTTPServer, - handler_class=APIHandler, -): - server_address = (host, port) - prompt_cache = PromptCache() - httpd = server_class( - server_address, - lambda *args, **kwargs: handler_class( - model_provider, - prompt_cache=prompt_cache, - system_fingerprint=get_system_fingerprint(), - *args, - **kwargs, - ), - ) - warnings.warn( - "mlx_lm.server is not recommended for production as " - "it only implements basic security checks." 
- ) - logging.info(f"Starting httpd at {host} on port {port}...") - httpd.serve_forever() - - -def main(): - parser = argparse.ArgumentParser(description="MLX Http Server.") - parser.add_argument( - "--model", - type=str, - help="The path to the MLX model weights, tokenizer, and config", - ) - parser.add_argument( - "--adapter-path", - type=str, - help="Optional path for the trained adapter weights and config.", - ) - parser.add_argument( - "--host", - type=str, - default="127.0.0.1", - help="Host for the HTTP server (default: 127.0.0.1)", - ) - parser.add_argument( - "--port", - type=int, - default=8080, - help="Port for the HTTP server (default: 8080)", - ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Enable trusting remote code for tokenizer", - ) - parser.add_argument( - "--log-level", - type=str, - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Set the logging level (default: INFO)", - ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - required=False, - ) - parser.add_argument( - "--chat-template", - type=str, - default="", - help="Specify a chat template for the tokenizer", - required=False, - ) - parser.add_argument( - "--use-default-chat-template", - action="store_true", - help="Use the default chat template", - ) - args = parser.parse_args() - - logging.basicConfig( - level=getattr(logging, args.log_level.upper(), None), - format="%(asctime)s - %(levelname)s - %(message)s", - ) - - if args.cache_limit_gb is not None: - logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB") - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - - run(args.host, args.port, ModelProvider(args)) - - -if __name__ == "__main__": - main() diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py deleted file mode 100644 index b33d504b..00000000 --- a/llms/mlx_lm/tokenizer_utils.py +++ /dev/null @@ -1,376 +0,0 @@ -import json -from functools import partial -from typing import List - -from transformers import AutoTokenizer - - -class StreamingDetokenizer: - """The streaming detokenizer interface so that we can detokenize one token at a time. - - Example usage is as follows: - - detokenizer = ... - - # Reset the tokenizer state - detokenizer.reset() - - for token in generate(...): - detokenizer.add_token(token.item()) - - # Contains the whole text so far. Some tokens may not be included - # since it contains whole words usually. - detokenizer.text - - # Contains the printable segment (usually a word) since the last - # time it was accessed - detokenizer.last_segment - - # Contains all the tokens added so far - detokenizer.tokens - - # Make sure that we detokenize any remaining tokens - detokenizer.finalize() - - # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens) - """ - - __slots__ = ("text", "tokens", "offset") - - def reset(self): - raise NotImplementedError() - - def add_token(self, token): - raise NotImplementedError() - - def finalize(self): - raise NotImplementedError() - - @property - def last_segment(self): - """Return the last segment of readable text since last time this property was accessed.""" - text = self.text - segment = text[self.offset :] - self.offset = len(text) - return segment - - -class NaiveStreamingDetokenizer(StreamingDetokenizer): - """NaiveStreamingDetokenizer relies on the underlying tokenizer - implementation and should work with every tokenizer. 
- - Its complexity is O(T^2) where T is the longest line since it will - repeatedly detokenize the same tokens until a new line is generated. - """ - - def __init__(self, tokenizer): - self._tokenizer = tokenizer - self._tokenizer.decode([0]) - self.reset() - - def reset(self): - self.offset = 0 - self.tokens = [] - self._text = "" - self._current_tokens = [] - self._current_text = "" - - def add_token(self, token): - self._current_tokens.append(token) - self.tokens.append(token) - - def finalize(self): - self._text += self._tokenizer.decode(self._current_tokens) - self._current_tokens = [] - self._current_text = "" - - @property - def text(self): - if self._current_tokens: - self._current_text = self._tokenizer.decode(self._current_tokens) - if ( - self._tokenizer.clean_up_tokenization_spaces - and self._current_text[-1] == " " - ): - self._current_text = self._current_text[:-1] - if self._current_text and self._current_text[-1] == "\n": - self._text += self._current_text - self._current_tokens.clear() - self._current_text = "" - return self._text + self._current_text - - -class SPMStreamingDetokenizer(StreamingDetokenizer): - """A streaming detokenizer for SPM models. - - It adds tokens to the text if the next token starts with the special SPM - underscore which results in linear complexity. - """ - - def __init__(self, tokenizer, trim_space=True): - self.trim_space = trim_space - self._sep = "\u2581".encode() - - # Extract the tokens in a list from id to text - self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) - for value, tokenid in tokenizer.vocab.items(): - if value.startswith("<0x"): - # Replace bytes with their value - self.tokenmap[tokenid] = bytes([int(value[3:5], 16)]) - else: - self.tokenmap[tokenid] = value.encode() - - self.reset() - - def reset(self): - self.offset = 0 - self._unflushed = b"" - self.text = "" - self.tokens = [] - - def _try_flush(self, force=False): - text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace") - if not force and text.endswith("\ufffd"): - return - if not self.text and self.trim_space and text and text[0] == " ": - text = text[1:] - self.text += text - self._unflushed = b"" - - def add_token(self, token): - self.tokens.append(token) - v = self.tokenmap[token] - self._unflushed += v - self._try_flush() - - def finalize(self): - self._try_flush(force=True) - self._unflushed = b"" - - -class BPEStreamingDetokenizer(StreamingDetokenizer): - """A streaming detokenizer for OpenAI style BPE models. - - It adds tokens to the text if the next token starts with a space similar to - the SPM detokenizer. 
- """ - - _byte_decoder = None - _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re") - - def __init__(self, tokenizer): - self.clean_spaces = tokenizer.clean_up_tokenization_spaces - - # Extract the tokens in a list from id to text - self.tokenmap = [None] * len(tokenizer.vocab) - for value, tokenid in tokenizer.vocab.items(): - self.tokenmap[tokenid] = value - - self.reset() - - # Make the BPE byte decoder from - # https://github.com/openai/gpt-2/blob/master/src/encoder.py - self.make_byte_decoder() - - def reset(self): - self.offset = 0 - self._unflushed = "" - self.text = "" - self.tokens = [] - - def _decode_bytes(self, seq): - barr = bytearray() - for c in seq: - res = self._byte_decoder.get(c, False) - if res: - barr.append(res) - else: - barr.extend(bytes(c, "utf-8")) - return barr.decode("utf-8", "replace") - - def _maybe_trim_space(self, current_text): - if len(current_text) == 0: - return current_text - elif current_text[0] != " ": - return current_text - elif not self.text: - return current_text[1:] - elif self.clean_spaces and current_text[1:].startswith(self._space_matches): - return current_text[1:] - return current_text - - def add_token(self, token): - self.tokens.append(token) - v = self.tokenmap[token] - self._unflushed += v - text = self._decode_bytes(self._unflushed) - - # For multi-byte utf-8 wait until they are complete - # For single spaces wait until the next token to clean it if needed - if not text.endswith("\ufffd") and not ( - len(v) == 1 and self._byte_decoder[v[0]] == 32 - ): - self.text += self._maybe_trim_space(text) - self._unflushed = "" - - def finalize(self): - current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( - "utf-8", - "replace", - ) - self.text += self._maybe_trim_space(current_text) - self._unflushed = "" - - @classmethod - def make_byte_decoder(cls): - """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale.""" - if cls._byte_decoder is not None: - return - - char_to_bytes = {} - limits = [ - 0, - ord("!"), - ord("~") + 1, - ord("¡"), - ord("¬") + 1, - ord("®"), - ord("ÿ") + 1, - ] - n = 0 - for i, (start, stop) in enumerate(zip(limits, limits[1:])): - if i % 2 == 0: - for b in range(start, stop): - char_to_bytes[chr(2**8 + n)] = b - n += 1 - else: - for b in range(start, stop): - char_to_bytes[chr(b)] = b - cls._byte_decoder = char_to_bytes - - -class TokenizerWrapper: - """A wrapper that combines an HF tokenizer and a detokenizer. - - Accessing any attribute other than the ``detokenizer`` is forwarded to the - huggingface tokenizer. 
- """ - - def __init__( - self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None - ): - self._tokenizer = tokenizer - self._detokenizer = detokenizer_class(tokenizer) - self._eos_token_ids = ( - set(eos_token_ids) - if eos_token_ids is not None - else {tokenizer.eos_token_id} - ) - - def add_eos_token(self, token: str): - token_id = None - try: - token_id = int(token) - except ValueError: - token_id = self._tokenizer.convert_tokens_to_ids(token) - - if token_id is None: - raise ValueError(f"'{token}' is not a token for this tokenizer") - - self._eos_token_ids.add(token_id) - - def __getattr__(self, attr): - if attr == "detokenizer": - return self._detokenizer - elif attr == "eos_token_ids": - return self._eos_token_ids - elif attr.startswith("_"): - return self.__getattribute__(attr) - else: - return getattr(self._tokenizer, attr) - - def __setattr__(self, attr, value): - if attr in {"detokenizer", "eos_token_ids"}: - if attr == "detokenizer": - raise AttributeError("Cannot set the detokenizer.") - elif attr == "eos_token_ids": - self._eos_token_ids = set(value) if value is not None else set() - elif attr.startswith("_"): - super().__setattr__(attr, value) - else: - setattr(self._tokenizer, attr, value) - - -def _match(a, b): - if type(a) != type(b): - return False - if isinstance(a, dict): - return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a) - if isinstance(a, list): - return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b)) - - return a == b - - -def _is_spm_decoder(decoder): - _target_description = { - "type": "Sequence", - "decoders": [ - {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, - {"type": "ByteFallback"}, - {"type": "Fuse"}, - {"type": "Strip", "content": " ", "start": 1, "stop": 0}, - ], - } - return _match(_target_description, decoder) - - -def _is_spm_decoder_no_space(decoder): - _target_description = { - "type": "Sequence", - "decoders": [ - {"type": "Replace", "pattern": {"String": "▁"}, "content": " "}, - {"type": "ByteFallback"}, - {"type": "Fuse"}, - ], - } - return _match(_target_description, decoder) - - -def _is_bpe_decoder(decoder): - return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" - - -def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): - """Load a huggingface tokenizer and try to infer the type of streaming - detokenizer to use. - - Note, to use a fast streaming tokenizer, pass a local file path rather than - a Hugging Face repo ID. 
- """ - detokenizer_class = NaiveStreamingDetokenizer - - tokenizer_file = model_path / "tokenizer.json" - if tokenizer_file.exists(): - with open(tokenizer_file, "r", encoding="utf-8") as fid: - tokenizer_content = json.load(fid) - if "decoder" in tokenizer_content: - if _is_spm_decoder(tokenizer_content["decoder"]): - detokenizer_class = SPMStreamingDetokenizer - elif _is_spm_decoder_no_space(tokenizer_content["decoder"]): - detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False) - elif _is_bpe_decoder(tokenizer_content["decoder"]): - detokenizer_class = BPEStreamingDetokenizer - - if isinstance(eos_token_ids, int): - eos_token_ids = [eos_token_ids] - return TokenizerWrapper( - AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), - detokenizer_class, - eos_token_ids=eos_token_ids, - ) - - -def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List: - removed_bos = sequence if sequence[0] != bos else sequence[1:] - return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos diff --git a/llms/mlx_lm/tuner/__init__.py b/llms/mlx_lm/tuner/__init__.py deleted file mode 100644 index 2e6d2f90..00000000 --- a/llms/mlx_lm/tuner/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .trainer import TrainingArgs, evaluate, train -from .utils import linear_to_lora_layers diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py deleted file mode 100644 index a6f3bd29..00000000 --- a/llms/mlx_lm/tuner/datasets.py +++ /dev/null @@ -1,273 +0,0 @@ -import itertools -import json -import types -from pathlib import Path -from typing import Any, Dict, List, Optional - -from transformers import PreTrainedTokenizer - - -class Dataset: - """ - Light-weight wrapper to hold a dataset. - """ - - def __init__( - self, - data: List[Dict[str, str]], - tokenizer: PreTrainedTokenizer, - text_key: str = "text", - ): - self._data = [tokenizer.encode(d[text_key]) for d in data] - for d in self._data: - if d[-1] != tokenizer.eos_token_id: - d.append(tokenizer.eos_token_id) - - def __getitem__(self, idx: int): - return self._data[idx] - - def __len__(self): - return len(self._data) - - -class ChatDataset: - """ - A dataset for chat data in the format of {"messages": [...]} - https://platform.openai.com/docs/guides/fine-tuning/example-format - """ - - def __init__( - self, - data: List[Dict[str, str]], - tokenizer: PreTrainedTokenizer, - chat_key: str = "messages", - mask_prompt: bool = False, - ): - self._data = [] - for d in data: - messages = d[chat_key] - tools = d.get("tools", None) - tokens = tokenizer.apply_chat_template(messages, tools=tools) - if mask_prompt: - messages = messages[:-1] - offset = len(tokenizer.apply_chat_template(messages, tools=tools)) - self._data.append((tokens, offset)) - else: - self._data.append(tokens) - - def __getitem__(self, idx: int): - return self._data[idx] - - def __len__(self): - return len(self._data) - - -class CompletionsDataset: - """ - A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...} - or using user-provided keys for prompt and completion values - https://platform.openai.com/docs/guides/fine-tuning/example-format - """ - - def __init__( - self, - data: List[Dict[str, str]], - tokenizer: PreTrainedTokenizer, - prompt_key: str, - completion_key: str, - mask_prompt: bool, - ): - self._data = [] - for d in data: - tokens = tokenizer.apply_chat_template( - [ - {"role": "user", "content": d[prompt_key]}, - {"role": "assistant", "content": d[completion_key]}, - ], - ) - if mask_prompt: - offset = 
len( - tokenizer.apply_chat_template( - [{"role": "user", "content": d[prompt_key]}] - ) - ) - self._data.append((tokens, offset)) - else: - self._data.append(tokens) - - def __getitem__(self, idx: int): - return self._data[idx] - - def __len__(self): - return len(self._data) - - -class ConcatenatedDataset: - def __init__(self, data: List[Any]): - self._data = list(itertools.chain(*data)) - - def __getitem__(self, idx: int): - return self._data[idx] - - def __len__(self): - return len(self._data) - - -def create_dataset( - data, - tokenizer: PreTrainedTokenizer, - config, -): - mask_prompt = getattr(config, "mask_prompt", False) - prompt_feature = getattr(config, "prompt_feature", "prompt") - text_feature = getattr(config, "text_feature", "text") - completion_feature = getattr(config, "completion_feature", "completion") - chat_feature = getattr(config, "chat_feature", "messages") - sample = data[0] - if prompt_feature in sample and completion_feature in sample: - return CompletionsDataset( - data, tokenizer, prompt_feature, completion_feature, mask_prompt - ) - elif chat_feature in sample: - return ChatDataset( - data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt - ) - elif text_feature in sample: - if mask_prompt: - raise ValueError("Prompt masking not supported for text dataset.") - return Dataset(data, tokenizer, text_key=text_feature) - else: - raise ValueError( - "Unsupported data format, check the supported formats here:\n" - "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data." - ) - - -def load_local_dataset( - data_path: Path, - tokenizer: PreTrainedTokenizer, - config, -): - def load_subset(path): - if not path.exists(): - return [] - with open(path, "r") as fid: - data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, config) - - names = ("train", "valid", "test") - train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] - return train, valid, test - - -def load_hf_dataset( - data_id: str, - tokenizer: PreTrainedTokenizer, - config, -): - from datasets import exceptions, load_dataset - - try: - dataset = load_dataset(data_id) - - names = ("train", "valid", "test") - - train, valid, test = [ - ( - create_dataset(dataset[n], tokenizer, config) - if n in dataset.keys() - else [] - ) - for n in names - ] - - except exceptions.DatasetNotFoundError: - raise ValueError(f"Not found Hugging Face dataset: {data_id} .") - - return train, valid, test - - -def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): - import datasets - - def create_hf_dataset(dataset_name, config, split, hf_config): - ds = datasets.load_dataset( - dataset_name, - split=split, - **hf_config, - ) - return create_dataset(ds, tokenizer, config) - - dataset_collection = args.hf_dataset - if isinstance(dataset_collection, dict): - dataset_collection = [dataset_collection] - - collection = [] - for ds in dataset_collection: - ds_name = ds["name"] - print(f"Loading Hugging Face dataset {ds_name}.") - ds["mask_prompt"] = getattr(args, "mask_prompt", False) - config = types.SimpleNamespace(**ds) - hf_config = ds.get("config", {}) - if args.train: - train_split = ds.get("train_split", "train[:80%]") - valid_split = ds.get("valid_split", "train[-10%:]") - train = create_hf_dataset( - ds_name, - config, - train_split, - hf_config, - ) - valid = create_hf_dataset( - ds_name, - config, - valid_split, - hf_config, - ) - else: - train, valid = [], [] - - if args.test: - test_split = ds.get("test_split") - test = create_hf_dataset( - 
ds_name, - config, - test_split, - hf_config, - ) - else: - test = [] - - collection.append((train, valid, test)) - - if len(collection) == 1: - return collection[0] - - # Otherwise concatenate them - return tuple(map(ConcatenatedDataset, zip(*collection))) - - -def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", False): - train, valid, test = load_custom_hf_dataset(args, tokenizer) - else: - data_path = Path(args.data) - if data_path.exists(): - train, valid, test = load_local_dataset(data_path, tokenizer, args) - else: - print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset(args.data, tokenizer, args) - - if args.train and len(train) == 0: - raise ValueError( - "Training set not found or empty. Must provide training set for fine-tuning." - ) - if args.train and len(valid) == 0: - raise ValueError( - "Validation set not found or empty. Must provide validation set for fine-tuning." - ) - if args.test and len(test) == 0: - raise ValueError( - "Test set not found or empty. Must provide test set for evaluation." - ) - return train, valid, test diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py deleted file mode 100644 index aba1f6f4..00000000 --- a/llms/mlx_lm/tuner/dora.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import math - -import mlx.core as mx -import mlx.nn as nn - - -class DoRALinear(nn.Module): - @staticmethod - def from_base( - linear: nn.Linear, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - # TODO remove when input_dims and output_dims are attributes - # on linear and quantized linear - output_dims, input_dims = linear.weight.shape - if isinstance(linear, nn.QuantizedLinear): - input_dims *= 32 // linear.bits - dora_lin = DoRALinear( - input_dims=input_dims, - output_dims=output_dims, - r=r, - dropout=dropout, - scale=scale, - ) - dora_lin.set_linear(linear) - return dora_lin - - def fuse(self, de_quantize: bool = False): - linear = self.linear - bias = "bias" in linear - weight = self._dequantized_weight() - - # Use the same type as the linear weight - dtype = weight.dtype - - output_dims, input_dims = weight.shape - fused_linear = nn.Linear(input_dims, output_dims, bias=False) - - lora_b = (self.scale * self.lora_b.T).astype(dtype) - lora_a = self.lora_a.T.astype(dtype) - weight = weight + lora_b @ lora_a - norm_scale = self.m / mx.linalg.norm(weight, axis=1) - fused_linear.weight = norm_scale[:, None] * weight - - if bias: - fused_linear.bias = linear.bias - - if self._is_quantized() and not de_quantize: - fused_linear = nn.QuantizedLinear.from_linear( - fused_linear, - linear.group_size, - linear.bits, - ) - return fused_linear - - def __init__( - self, - input_dims: int, - output_dims: int, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - bias: bool = False, - ): - super().__init__() - - # Regular linear layer weights - self.set_linear(nn.Linear(input_dims, output_dims, bias=bias)) - self.dropout = nn.Dropout(p=dropout) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale = 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(input_dims, r), - ) - self.lora_b = mx.zeros(shape=(r, output_dims)) - - def set_linear(self, linear): - """ - Set the self.linear layer and recompute self.m. 
- """ - self.linear = linear - self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1) - - def _dequantized_weight(self): - """ - Return the weight of linear layer and dequantize it if is quantized - """ - weight = self.linear.weight - if self._is_quantized(): - weight = mx.dequantize( - weight, - self.linear.scales, - self.linear.biases, - self.linear.group_size, - self.linear.bits, - ) - return weight - - def _is_quantized(self): - return isinstance(self.linear, nn.QuantizedLinear) - - def __call__(self, x): - # Regular LoRA (without a bias) - w = self._dequantized_weight() - y = x @ w.T - - z = (self.dropout(x) @ self.lora_a) @ self.lora_b - out = y + (self.scale * z).astype(x.dtype) - - # Compute the norm of the adapted weights - adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T - denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1)) - - # Remove the norm and scale by the learned magnitude - out = (self.m / denom).astype(x.dtype) * out - - if "bias" in self.linear: - out = out + self.linear.bias - return out - - -class DoRAEmbedding(nn.Module): - def from_base( - embedding: nn.Embedding, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - num_embeddings, dims = embedding.weight.shape - - # TODO support quantized weights in DoRALinear - if isinstance(embedding, nn.QuantizedLinear): - raise ValueError("DoRAEmbedding does not yet support quantization.") - dora_embedding = DoRAEmbedding( - num_embeddings=num_embeddings, - dims=dims, - r=r, - dropout=dropout, - scale=scale, - ) - dora_embedding.set_embedding(embedding) - return dora_embedding - - def fuse(self, de_quantize: bool = False): - embedding = self.embedding - weight = embedding.weight - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - num_embeddings, dims = weight.shape - fused_embedding = nn.Embedding(num_embeddings, dims) - - lora_a = (self.scale * self.lora_a).astype(dtype) - lora_b = self.lora_b.astype(dtype) - weight = weight + lora_a @ lora_b - norm_scale = self.m / mx.linalg.norm(weight, axis=1) - fused_embedding.weight = norm_scale[:, None] * weight - - return fused_embedding - - def __init__( - self, - num_embeddings: int, - dims: int, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - super().__init__() - - # Regular embedding layer weights - self.set_embedding(nn.Embedding(num_embeddings, dims)) - self.dropout = nn.Dropout(p=dropout) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale = 1 / math.sqrt(num_embeddings) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(num_embeddings, r), - ) - self.lora_b = mx.zeros(shape=(r, dims)) - - def set_embedding(self, embedding: nn.Module): - self.embedding = embedding - self.m = mx.linalg.norm(embedding.weight, axis=1) - - def __call__(self, x): - y = self.embedding(x) - z = self.scale * self.lora_a[x] @ self.lora_b - out = y + self.dropout(z).astype(y.dtype) - - # Compute the norm of the adapted weights for the individual embeddings - adapted = y + z - denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1)) - - # Remove the norm and scale by the learned magnitude - out = (self.m[x] / denom)[..., None] * out - - return out - - def as_linear(self, x): - y = self.embedding.as_linear(x) - z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T - out = y + (self.scale * z).astype(x.dtype) - - # Compute the norm of the adapted weights - adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b - denom = 
mx.stop_gradient(mx.linalg.norm(adapted, axis=1)) - - # Remove the norm and scale by the learned magnitude - out = (self.m / denom) * out - - return out diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py deleted file mode 100644 index c788cb73..00000000 --- a/llms/mlx_lm/tuner/lora.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import math - -import mlx.core as mx -import mlx.nn as nn - -from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear - - -class LoRALinear(nn.Module): - @staticmethod - def from_base( - linear: nn.Linear, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - # TODO remove when input_dims and output_dims are attributes - # on linear and quantized linear - output_dims, input_dims = linear.weight.shape - if isinstance(linear, nn.QuantizedLinear): - input_dims *= 32 // linear.bits - lora_lin = LoRALinear( - input_dims=input_dims, - output_dims=output_dims, - r=r, - dropout=dropout, - scale=scale, - ) - lora_lin.linear = linear - return lora_lin - - def fuse(self, de_quantize: bool = False): - linear = self.linear - bias = "bias" in linear - weight = linear.weight - is_quantized = isinstance(linear, nn.QuantizedLinear) - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - if is_quantized: - dtype = linear.scales.dtype - weight = mx.dequantize( - weight, - linear.scales, - linear.biases, - linear.group_size, - linear.bits, - ) - output_dims, input_dims = weight.shape - fused_linear = nn.Linear(input_dims, output_dims, bias=bias) - - lora_b = (self.scale * self.lora_b.T).astype(dtype) - lora_a = self.lora_a.T.astype(dtype) - fused_linear.weight = weight + lora_b @ lora_a - if bias: - fused_linear.bias = linear.bias - - if is_quantized and not de_quantize: - fused_linear = nn.QuantizedLinear.from_linear( - fused_linear, - linear.group_size, - linear.bits, - ) - - return fused_linear - - def __init__( - self, - input_dims: int, - output_dims: int, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - bias: bool = False, - ): - super().__init__() - - # Regular linear layer weights - self.linear = nn.Linear(input_dims, output_dims, bias=bias) - - self.dropout = nn.Dropout(p=dropout) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale = 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(input_dims, r), - ) - self.lora_b = mx.zeros(shape=(r, output_dims)) - - def __call__(self, x): - y = self.linear(x) - z = (self.dropout(x) @ self.lora_a) @ self.lora_b - return y + (self.scale * z).astype(x.dtype) - - -class LoRASwitchLinear(nn.Module): - @staticmethod - def from_base( - linear: nn.Module, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - lora_lin = LoRASwitchLinear( - input_dims=linear.input_dims, - output_dims=linear.output_dims, - num_experts=linear.num_experts, - r=r, - dropout=dropout, - scale=scale, - ) - lora_lin.linear = linear - return lora_lin - - def fuse(self, de_quantize: bool = False): - linear = self.linear - bias = "bias" in linear - weight = linear.weight - is_quantized = isinstance(linear, QuantizedSwitchLinear) - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - if is_quantized: - dtype = mx.float16 - weight = mx.dequantize( - weight, - linear.scales, - linear.biases, - linear.group_size, - linear.bits, - ) - num_experts, output_dims, input_dims = weight.shape - fused_linear = SwitchLinear(input_dims, output_dims, 
num_experts, bias=bias) - - lora_b = (self.scale * self.lora_b).astype(dtype) - lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype) - fused_linear.weight = weight + lora_b @ lora_a - if bias: - fused_linear.bias = linear.bias - - if is_quantized and not de_quantize: - fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits) - - return fused_linear - - def __init__( - self, - input_dims: int, - output_dims: int, - num_experts: int, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - bias: bool = False, - ): - super().__init__() - - # Regular linear layer weights - self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) - - self.dropout = nn.Dropout(p=dropout) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale = 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(r * num_experts, input_dims), - ) - self.lora_b = mx.zeros(shape=(num_experts, output_dims, r)) - self.num_experts = num_experts - - def __call__(self, x, indices): - shape = x.shape[:-3] + (self.num_experts, -1) - - y = self.linear(x, indices) - z = (self.dropout(x) @ self.lora_a.T).reshape(shape) - z = mx.take_along_axis(z, indices[..., None], axis=-2) - z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1) - - return y + (self.scale * z).astype(x.dtype) - - -class LoRAEmbedding(nn.Module): - @staticmethod - def from_base( - embedding: nn.Embedding, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - num_embeddings, dims = embedding.weight.shape - if isinstance(embedding, nn.QuantizedEmbedding): - dims *= 32 // embedding.bits - lora_embedding = LoRAEmbedding( - num_embeddings=num_embeddings, - dims=dims, - r=r, - dropout=dropout, - scale=scale, - ) - lora_embedding.embedding = embedding - return lora_embedding - - def fuse(self, de_quantize: bool = False): - embedding = self.embedding - weight = embedding.weight - is_quantized = isinstance(embedding, nn.QuantizedEmbedding) - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - if is_quantized: - dtype = embedding.scales.dtype - weight = mx.dequantize( - weight, - embedding.scales, - embedding.biases, - embedding.group_size, - embedding.bits, - ) - num_embeddings, dims = weight.shape - fused_embedding = nn.Embedding(num_embeddings, dims) - - lora_a = (self.scale * self.lora_a).astype(dtype) - lora_b = self.lora_b.astype(dtype) - fused_embedding.weight = weight + lora_a @ lora_b - - if is_quantized and not de_quantize: - fused_embedding = nn.QuantizedEmbedding.from_embedding( - fused_embedding, - embedding.group_size, - embedding.bits, - ) - - return fused_embedding - - def __init__( - self, - num_embeddings: int, - dims: int, - r: int = 8, - dropout: float = 0.0, - scale: float = 20.0, - ): - super().__init__() - - # Regular embedding layer - self.embedding = nn.Embedding(num_embeddings, dims) - self.dropout = nn.Dropout(p=dropout) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale = 1 / math.sqrt(num_embeddings) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(num_embeddings, r), - ) - self.lora_b = mx.zeros(shape=(r, dims)) - - def __call__(self, x): - y = self.embedding(x) - z = self.dropout(self.lora_a[x] @ self.lora_b) - out = y + (self.scale * z).astype(y.dtype) - return out - - def as_linear(self, x): - y = self.embedding.as_linear(x) - z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T - return y + 
(self.scale * z).astype(x.dtype) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py deleted file mode 100644 index 64e26af8..00000000 --- a/llms/mlx_lm/tuner/trainer.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import glob -import shutil -import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn -import numpy as np -from mlx.nn.utils import average_gradients -from mlx.utils import tree_flatten -from transformers import PreTrainedTokenizer - -from .datasets import CompletionsDataset - - -def grad_checkpoint(layer): - """ - Update all instances of type(layer) to use gradient checkpointing. - """ - fn = type(layer).__call__ - - def checkpointed_fn(model, *args, **kwargs): - def inner_fn(params, *args, **kwargs): - model.update(params) - return fn(model, *args, **kwargs) - - return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) - - type(layer).__call__ = checkpointed_fn - - -@dataclass -class TrainingArgs: - batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) - iters: int = field(default=100, metadata={"help": "Iterations to train for."}) - val_batches: int = field( - default=25, - metadata={ - "help": "Number of validation batches, -1 uses the entire validation set." - }, - ) - steps_per_report: int = field( - default=10, - metadata={"help": "Number of training steps between loss reporting."}, - ) - steps_per_eval: int = field( - default=200, metadata={"help": "Number of training steps between validations."} - ) - steps_per_save: int = field( - default=100, metadata={"help": "Save the model every number steps"} - ) - max_seq_length: int = field( - default=2048, metadata={"help": "Maximum sequence length."} - ) - adapter_file: str = field( - default="adapters.safetensors", - metadata={"help": "Save/load path for the trained adapter weights."}, - ) - grad_checkpoint: bool = field( - default=False, - metadata={"help": "Use gradient checkpointing to reduce memory use."}, - ) - - -def default_loss(model, batch, lengths): - inputs = batch[:, :-1] - targets = batch[:, 1:] - - logits = model(inputs) - logits = logits.astype(mx.float32) - - steps = mx.arange(1, targets.shape[1] + 1) - mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) - - ce = nn.losses.cross_entropy(logits, targets) * mask - ntoks = mask.sum() - ce = ce.sum() / ntoks - - return ce, ntoks - - -def iterate_batches( - dataset, - tokenizer, - batch_size, - max_seq_length, - train=False, -): - # Sort by length: - idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) - if len(dataset) < batch_size: - raise ValueError( - f"Dataset must have at least batch_size={batch_size}" - f" examples but only has {len(dataset)}." 
- ) - - # If running in distributed mode (N machines) then each one should skip N-1 - # samples - step = mx.distributed.init().size() - if batch_size % step != 0: - raise ValueError("The batch size must be divisible by the number of workers") - - # Make the batches: - batch_idx = [ - idx[i : i + batch_size : step] - for i in range(0, len(idx) - batch_size + 1, batch_size) - ] - - while True: - indices = np.random.permutation(len(batch_idx)) - for i in indices: - batch = [dataset[j] for j in batch_idx[i]] - if len(batch[0]) == 2: - batch, offsets = zip(*batch) - else: - offsets = [0] * len(batch) - lengths = [len(x) for x in batch] - if max(lengths) > max_seq_length: - print( - f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " - f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " - "Consider pre-splitting your data to save memory." - ) - - # Pad to the nearest multiple of 8 or the maximum length - pad_to = 8 - max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) - max_length_in_batch = min(max_length_in_batch, max_seq_length) - - batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - - for j in range(batch_size // step): - truncated_length = min(lengths[j], max_seq_length) - batch_arr[j, :truncated_length] = batch[j][:truncated_length] - lengths[j] = ( - truncated_length # Update lengths to match truncated lengths - ) - batch = mx.array(batch_arr) - yield batch, mx.array(list(zip(offsets, lengths))) - - if not train: - break - - -def evaluate( - model, - dataset, - tokenizer, - batch_size, - num_batches, - max_seq_length=2048, - loss: callable = default_loss, - iterate_batches: callable = iterate_batches, -): - all_losses = mx.array(0.0) - ntokens = mx.array(0) - - index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) - - for _, batch in zip( - index_iterator, - iterate_batches( - dataset=dataset, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_length=max_seq_length, - ), - ): - losses, toks = loss(model, *batch) - all_losses += losses * toks - ntokens += toks - mx.eval(all_losses, ntokens) - - all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) - ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) - - return (all_losses / ntokens).item() - - -class TrainingCallback: - - def on_train_loss_report(self, train_info: dict): - """Called to report training loss at specified intervals.""" - pass - - def on_val_loss_report(self, val_info: dict): - """Called to report validation loss at specified intervals or the beginning.""" - pass - - -def train( - model, - tokenizer, - optimizer, - train_dataset, - val_dataset, - args: TrainingArgs = TrainingArgs(), - loss: callable = default_loss, - iterate_batches: callable = iterate_batches, - training_callback: TrainingCallback = None, -): - print(f"Starting training..., iters: {args.iters}") - world = mx.distributed.init() - world_size = world.size() - rank = world.rank() - if world_size > 1: - print(f"Node {rank} of {world_size}") - - if args.grad_checkpoint: - grad_checkpoint(model.layers[0]) - - state = [model.state, optimizer.state] - - def step(batch): - # Forward and backward pass - (lvalue, toks), grad = loss_value_and_grad(model, *batch) - - # All reduce the gradients if running in distributed mode - grad = average_gradients(grad) - - # Model update - optimizer.update(model, grad) - - return lvalue, toks - - loss_value_and_grad = nn.value_and_grad(model, loss) - - losses = 0 - n_tokens = 0 - steps = 0 - 
trained_tokens = 0 - train_time = 0 - # Main training loop - for it, batch in zip( - range(1, args.iters + 1), - iterate_batches( - dataset=train_dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - max_seq_length=args.max_seq_length, - train=True, - ), - ): - tic = time.perf_counter() - # Report validation loss if needed, the first validation loss - # is always measured before any training. - if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: - tic = time.perf_counter() - val_loss = evaluate( - model=model, - dataset=val_dataset, - loss=loss, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_batches=args.val_batches, - max_seq_length=args.max_seq_length, - iterate_batches=iterate_batches, - ) - val_time = time.perf_counter() - tic - if rank == 0: - print( - f"Iter {it}: " - f"Val loss {val_loss:.3f}, " - f"Val took {val_time:.3f}s", - flush=True, - ) - - if training_callback is not None: - val_info = { - "iteration": it, - "val_loss": val_loss, - "val_time": val_time, - } - training_callback.on_val_loss_report(val_info) - - tic = time.perf_counter() - - lvalue, toks = step(batch) - losses += lvalue - n_tokens += toks - steps += 1 - mx.eval(state, losses, n_tokens) - train_time += time.perf_counter() - tic - - # Report training loss if needed - if it % args.steps_per_report == 0 or it == args.iters: - train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() - train_loss /= steps * mx.distributed.init().size() - n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() - learning_rate = optimizer.learning_rate.item() - it_sec = args.steps_per_report / train_time - tokens_sec = float(n_tokens) / train_time - trained_tokens += n_tokens - peak_mem = mx.metal.get_peak_memory() / 1e9 - if rank == 0: - print( - f"Iter {it}: Train loss {train_loss:.3f}, " - f"Learning Rate {learning_rate:.3e}, " - f"It/sec {it_sec:.3f}, " - f"Tokens/sec {tokens_sec:.3f}, " - f"Trained Tokens {trained_tokens}, " - f"Peak mem {peak_mem:.3f} GB", - flush=True, - ) - - if training_callback is not None: - train_info = { - "iteration": it, - "train_loss": train_loss, - "learning_rate": learning_rate, - "iterations_per_second": it_sec, - "tokens_per_second": tokens_sec, - "trained_tokens": trained_tokens, - "peak_memory": peak_mem, - } - training_callback.on_train_loss_report(train_info) - - losses = 0 - n_tokens = 0 - steps = 0 - train_time = 0 - - # Save adapter weights - if it % args.steps_per_save == 0: - adapter_weights = dict(tree_flatten(model.trainable_parameters())) - mx.save_safetensors(str(args.adapter_file), adapter_weights) - checkpoint = ( - Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" - ) - mx.save_safetensors(str(checkpoint), adapter_weights) - print( - f"Iter {it}: Saved adapter weights to " - f"{args.adapter_file} and {checkpoint}." - ) - - # Save final weights - adapter_weights = dict(tree_flatten(model.trainable_parameters())) - mx.save_safetensors(str(args.adapter_file), adapter_weights) - print(f"Saved final weights to {args.adapter_file}.") diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py deleted file mode 100644 index cc7c6c20..00000000 --- a/llms/mlx_lm/tuner/utils.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright © 2024 Apple Inc. 
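# Illustrative sketch (an assumption, not code from the files above): one way to use
# the TrainingCallback hooks defined in mlx_lm/tuner/trainer.py. The metric keys come
# from the train_info/val_info dicts built inside train(); "metrics.jsonl" is a
# placeholder path.
import json

from mlx_lm.tuner.trainer import TrainingCallback


class JSONLCallback(TrainingCallback):
    """Append each reported metrics dict as one JSON line."""

    def __init__(self, path: str = "metrics.jsonl"):
        self.path = path

    def on_train_loss_report(self, train_info: dict):
        # train_info holds iteration, train_loss, learning_rate, tokens_per_second, ...
        with open(self.path, "a") as f:
            f.write(json.dumps(train_info) + "\n")

    def on_val_loss_report(self, val_info: dict):
        # val_info holds iteration, val_loss, and val_time
        with open(self.path, "a") as f:
            f.write(json.dumps(val_info) + "\n")


# Pass an instance via train(..., training_callback=JSONLCallback()).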
-import json -import types -from pathlib import Path -from typing import Dict - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as opt -from mlx.utils import tree_flatten, tree_unflatten - -from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear -from .dora import DoRAEmbedding, DoRALinear -from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear - - -def build_schedule(schedule_config: Dict): - """ - Build a learning rate schedule from the given config. - """ - schedule_fn = getattr(opt.schedulers, schedule_config["name"]) - arguments = schedule_config["arguments"] - initial_lr = arguments[0] - bound_schedule_fn = schedule_fn(*arguments) - if warmup_steps := schedule_config.get("warmup", 0): - warmup_init = schedule_config.get("warmup_init", 0.0) - warmup_fn = opt.schedulers.linear_schedule( - warmup_init, initial_lr, warmup_steps - ) - return opt.schedulers.join_schedules( - [warmup_fn, bound_schedule_fn], [warmup_steps + 1] - ) - else: - return bound_schedule_fn - - -def linear_to_lora_layers( - model: nn.Module, - num_layers: int, - config: Dict, - use_dora: bool = False, -): - """ - Convert some of the models linear layers to lora layers. - - Args: - model (nn.Module): The neural network model. - num_layers (int): The number of blocks to convert to lora layers - starting from the last layer. - config (dict): More configuration parameters for LoRA, including the - rank, scale, and optional layer keys. - use_dora (bool): If True, uses DoRA instead of LoRA. - Default: ``False`` - """ - - def to_lora(layer): - if isinstance(layer, (nn.Linear, nn.QuantizedLinear)): - LoRALayer = DoRALinear if use_dora else LoRALinear - elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)): - if use_dora: - raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.") - LoRALayer = LoRASwitchLinear - elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)): - LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding - else: - raise ValueError( - f"Can't convert layer of type {type(layer).__name__} to LoRA" - ) - - return LoRALayer.from_base( - layer, - r=config["rank"], - scale=config["scale"], - dropout=config["dropout"], - ) - - keys = config.get("keys", None) - if keys is not None: - keys = set(keys) - elif model.model_type in [ - "mistral", - "llama", - "phi", - "mixtral", - "nemotron", - "stablelm", - "hunyuan", - "qwen2", - "qwen2_moe", - "phimoe", - "gemma", - "gemma2", - "granite", - "helium", - "starcoder2", - "cohere", - "cohere2", - "minicpm", - "deepseek", - "olmo2", - "olmoe", - "internlm3", - ]: - keys = set(["self_attn.q_proj", "self_attn.v_proj"]) - if model.model_type in ["mixtral", "phimoe"]: - keys.add("block_sparse_moe.gate") - if model.model_type == "qwen2_moe": - keys.add("mlp.gate") - keys.add("mlp.shared_expert_gate") - if model.model_type == "olmoe": - keys.add("mlp.gate") - - elif model.model_type == "gpt_bigcode": - keys = set(["attn.c_attn"]) - elif model.model_type == "gpt2": - keys = set(["attn.c_attn"]) - elif model.model_type == "gpt_neox": - keys = set(["attention.query_key_value"]) - elif model.model_type == "olmo": - keys = set(["att_proj"]) - elif model.model_type == "openelm": - keys = set(["attn.qkv_proj"]) - elif model.model_type == "phi3": - keys = set(["self_attn.qkv_proj"]) - elif model.model_type == "phi-msft": - keys = set(["mixer.Wqkv", "moe.gate"]) - elif model.model_type == "dbrx": - keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"]) - elif model.model_type == "internlm2": - keys = 
set(["attention.wqkv", "attention.wo"]) - elif model.model_type == "deepseek_v2": - keys = set( - [ - "self_attn.q_proj", - "self_attn.q_a_proj", - "self_attn.q_b_proj", - "self_attn.kv_a_proj_with_mqa", - "self_attn.kv_b_proj", - ] - ) - elif model.model_type == "mamba": - keys = set( - [ - "mixer.in_proj", - "mixer.x_proj", - "mixer.dt_proj", - "mixer.out_proj", - ] - ) - elif model.model_type == "exaone": - keys = set(["attn.attention.q_proj", "attn.attention.v_proj"]) - else: - raise ValueError(f"Lora does not support {model.model_type}") - - for l in model.layers[-max(num_layers, 0) :]: - lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] - if lora_layers: - l.update_modules(tree_unflatten(lora_layers)) - - lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys] - if lora_modules: - model.update_modules(tree_unflatten(lora_modules)) - - -def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module: - """ - Load any fine-tuned adapters / layers. - - Args: - model (nn.Module): The neural network model. - adapter_path (str): Path to the adapter configuration file. - - Returns: - nn.Module: The updated model with LoRA layers applied. - """ - adapter_path = Path(adapter_path) - if not adapter_path.exists(): - raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") - with open(adapter_path / "adapter_config.json", "r") as fid: - config = types.SimpleNamespace(**json.load(fid)) - fine_tune_type = getattr(config, "fine_tune_type", "lora") - if fine_tune_type != "full": - linear_to_lora_layers( - model, - config.num_layers, - config.lora_parameters, - use_dora=(fine_tune_type == "dora"), - ) - model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) - return model - - -def dequantize(model: nn.Module) -> nn.Module: - """ - Dequantize the quantized linear layers in the model. - - Args: - model (nn.Module): The model with quantized linear layers. - - Returns: - nn.Module: The model with dequantized layers. - """ - de_quantize_layers = [] - for name, module in model.named_modules(): - if isinstance(module, nn.QuantizedLinear): - bias = "bias" in module - weight = module.weight - weight = mx.dequantize( - weight, - module.scales, - module.biases, - module.group_size, - module.bits, - ).astype(mx.float16) - output_dims, input_dims = weight.shape - linear = nn.Linear(input_dims, output_dims, bias=bias) - linear.weight = weight - if bias: - linear.bias = module.bias - de_quantize_layers.append((name, linear)) - if isinstance(module, nn.QuantizedEmbedding): - weight = mx.dequantize( - module.weight, - module.scales, - module.biases, - module.group_size, - module.bits, - ).astype(mx.float16) - num_embeddings, dims = weight.shape - emb = nn.Embedding(num_embeddings, dims) - emb.weight = weight - de_quantize_layers.append((name, emb)) - - if len(de_quantize_layers) > 0: - model.update_modules(tree_unflatten(de_quantize_layers)) - return model - - -def remove_lora_layers(model: nn.Module) -> nn.Module: - """ - Remove the LoRA layers from the model. - - Args: - model (nn.Module): The model with LoRA layers. - - Returns: - nn.Module: The model without LoRA layers. 
- """ - reset_layers = [] - for name, module in model.named_modules(): - if isinstance(module, LoRALinear): - reset_layers.append((name, module.linear)) - if len(reset_layers) > 0: - model.update_modules(tree_unflatten(reset_layers)) - return model - - -def nparams(module): - if hasattr(module, "bits"): - n = 0 if not hasattr(module, "bias") else module.bias.size - return n + module.weight.size * 32 // module.bits - return sum(v.size for _, v in tree_flatten(module.parameters())) - - -def print_trainable_parameters(model): - leaf_modules = tree_flatten( - model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) - ) - total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 - trainable_p = ( - sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 - ) - print( - f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% " - f"({trainable_p:.3f}M/{total_p:.3f}M)" - ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py deleted file mode 100644 index 05fac92f..00000000 --- a/llms/mlx_lm/utils.py +++ /dev/null @@ -1,1118 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import contextlib -import copy -import functools -import glob -import importlib -import json -import logging -import os -import shutil -import time -from dataclasses import dataclass -from pathlib import Path -from textwrap import dedent -from typing import ( - Any, - Callable, - Dict, - Generator, - List, - NamedTuple, - Optional, - Tuple, - Type, - Union, -) - -import mlx.core as mx -import mlx.nn as nn - -if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": - try: - from modelscope import snapshot_download - except ImportError: - raise ImportError( - "Please run `pip install modelscope` to activate the ModelScope." - ) -else: - from huggingface_hub import snapshot_download - -from mlx.utils import tree_flatten, tree_reduce -from transformers import PreTrainedTokenizer - -# Local imports -from .models import cache -from .sample_utils import make_logits_processors, make_sampler -from .tokenizer_utils import TokenizerWrapper, load_tokenizer -from .tuner.utils import dequantize as dequantize_model -from .tuner.utils import load_adapters, nparams - -# Constants -MODEL_REMAPPING = { - "mistral": "llama", # mistral is compatible with llama - "phi-msft": "phixtral", - "falcon_mamba": "mamba", -} - -MAX_FILE_SIZE_GB = 5 - -# A stream on the default device just for generation -generation_stream = mx.new_stream(mx.default_device()) - - -class ModelNotFoundError(Exception): - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -@dataclass -class GenerationResponse: - """ - The output of :func:`stream_generate`. - - Args: - text (str): The next segment of decoded text. This can be an empty string. - token (int): The next token. - from_draft (bool): Whether the token was generated by the draft model. - logprobs (mx.array): A vector of log probabilities. - prompt_tokens (int): The number of tokens in the prompt. - prompt_tps (float): The prompt processing tokens-per-second. - generation_tokens (int): The number of generated tokens. - generation_tps (float): The tokens-per-second for generation. - peak_memory (float): The peak memory used so far in GB. 
- finish_reason (str): The reason the response is being sent: "length", "stop" or `None` - """ - - text: str - token: int - logprobs: mx.array - from_draft: bool - prompt_tokens: int - prompt_tps: float - generation_tokens: int - generation_tps: float - peak_memory: float - finish_reason: Optional[str] = None - - -@contextlib.contextmanager -def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): - """ - A context manager to temporarily change the wired limit. - - Note, the wired limit should not be changed during an async eval. If an - async eval could be running pass in the streams to synchronize with prior - to exiting the context manager. - """ - model_bytes = tree_reduce( - lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 - ) - max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"] - if model_bytes > 0.9 * max_rec_size: - model_mb = model_bytes // 2**20 - max_rec_mb = max_rec_size // 2**20 - print( - f"[WARNING] Generating with a model that requires {model_mb} MB " - f"which is close to the maximum recommended size of {max_rec_mb} " - "MB. This can be slow. See the documentation for possible work-arounds: " - "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models" - ) - old_limit = mx.metal.set_wired_limit(max_rec_size) - try: - yield None - finally: - if streams is not None: - for s in streams: - mx.synchronize(s) - else: - mx.synchronize() - mx.metal.set_wired_limit(old_limit) - - -def _get_classes(config: dict): - """ - Retrieve the model and model args classes based on the configuration. - - Args: - config (dict): The model configuration. - - Returns: - A tuple containing the Model class and the ModelArgs class. - """ - model_type = config["model_type"] - model_type = MODEL_REMAPPING.get(model_type, model_type) - try: - arch = importlib.import_module(f"mlx_lm.models.{model_type}") - except ImportError: - msg = f"Model type {model_type} not supported." - logging.error(msg) - raise ValueError(msg) - - return arch.Model, arch.ModelArgs - - -def compute_bits_per_weight(model): - model_bytes = tree_reduce( - lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 - ) - leaf_modules = tree_flatten( - model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) - ) - model_params = sum(nparams(m) for _, m in leaf_modules) - return model_bytes * 8 / model_params - - -def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: - """ - Ensures the model is available locally. If the path does not exist locally, - it is downloaded from the Hugging Face Hub. - - Args: - path_or_hf_repo (str): The local path or Hugging Face repository ID of the model. - revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. - - Returns: - Path: The path to the model. 
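    Example (illustrative; the repo id is only a sample):

        >>> model_path = get_model_path("mlx-community/Mistral-7B-Instruct-v0.3-4bit")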
- """ - model_path = Path(path_or_hf_repo) - - if not model_path.exists(): - try: - model_path = Path( - snapshot_download( - path_or_hf_repo, - revision=revision, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - "tiktoken.model", - "*.txt", - "*.jsonl", - ], - ) - ) - except: - raise ModelNotFoundError( - f"Model not found for path or HF repo: {path_or_hf_repo}.\n" - "Please make sure you specified the local path or Hugging Face" - " repo id correctly.\nIf you are trying to access a private or" - " gated Hugging Face repo, make sure you are authenticated:\n" - "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login" - ) from None - return model_path - - -def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): - if ( - kv_bits is not None - and not isinstance(prompt_cache[0], cache.QuantizedKVCache) - and prompt_cache[0].offset > quantized_kv_start - ): - for i in range(len(prompt_cache)): - if isinstance(prompt_cache[i], cache.KVCache): - prompt_cache[i] = prompt_cache[i].to_quantized( - group_size=kv_group_size, bits=kv_bits - ) - - -def generate_step( - prompt: mx.array, - model: nn.Module, - *, - max_tokens: int = 256, - sampler: Optional[Callable[mx.array, mx.array]] = None, - logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, - max_kv_size: Optional[int] = None, - prompt_cache: Optional[Any] = None, - prefill_step_size: int = 512, - kv_bits: Optional[int] = None, - kv_group_size: int = 64, - quantized_kv_start: int = 0, - prompt_progress_callback: Optional[Callable[int, int]] = None, -) -> Generator[Tuple[mx.array, mx.array], None, None]: - """ - A generator producing token ids based on the given prompt from the model. - - Args: - prompt (mx.array): The input prompt. - model (nn.Module): The model to use for generation. - max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite - generator. Default: ``256``. - sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a - token from a vector of log probabilities. Default: ``None``. - logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. - max_kv_size (int, optional): Maximum size of the key-value cache. Old - entries (except the first 4 tokens) will be overwritten. - prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if - provided, the cache will be updated in place. - prefill_step_size (int): Step size for processing the prompt. - kv_bits (int, optional): Number of bits to use for KV cache quantization. - None implies no cache quantization. Default: ``None``. - kv_group_size (int): Group size for KV cache quantization. Default: ``64``. - quantized_kv_start (int): Step to begin using a quantized KV cache. - when ``kv_bits`` is non-None. Default: ``0``. - prompt_prorgress_callback (Callable[int, int]): A call-back which takes the - prompt tokens processed so far and the total number of prompt tokens. - - Yields: - Tuple[mx.array, mx.array]: One token and a vector of log probabilities. 
- """ - - y = prompt - tokens = None - - # Create the KV cache for generation - if prompt_cache is None: - prompt_cache = cache.make_prompt_cache( - model, - max_kv_size=max_kv_size, - ) - elif len(prompt_cache) != len(model.layers): - raise ValueError("Wrong number of layers in the prompt cache.") - - prompt_progress_callback = prompt_progress_callback or (lambda *_: None) - - quantize_cache_fn = functools.partial( - maybe_quantize_kv_cache, - quantized_kv_start=quantized_kv_start, - kv_group_size=kv_group_size, - kv_bits=kv_bits, - ) - - sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) - - def _step(y): - with mx.stream(generation_stream): - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] - - if logits_processors: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is not None else y - - for processor in logits_processors: - logits = processor(tokens, logits) - - quantize_cache_fn(prompt_cache) - - logprobs = logits - mx.logsumexp(logits, keepdims=True) - y = sampler(logprobs) - return y, logprobs.squeeze(0) - - with mx.stream(generation_stream): - total_prompt_tokens = y.size - prompt_processed_tokens = 0 - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=prompt_cache) - quantize_cache_fn(prompt_cache) - mx.eval([c.state for c in prompt_cache]) - prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) - prompt_processed_tokens += prefill_step_size - y = y[prefill_step_size:] - mx.metal.clear_cache() - - y, logprobs = _step(y) - - mx.async_eval(y, logprobs) - n = 0 - while True: - if n != max_tokens: - next_y, next_logprobs = _step(y) - mx.async_eval(next_y, next_logprobs) - if n == 0: - mx.eval(y) - prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) - if n == max_tokens: - break - yield y.item(), logprobs - if n % 256 == 0: - mx.metal.clear_cache() - y, logprobs = next_y, next_logprobs - n += 1 - - -def speculative_generate_step( - prompt: mx.array, - model: nn.Module, - draft_model: nn.Module, - *, - num_draft_tokens=2, - max_tokens: int = 256, - sampler: Optional[Callable[mx.array, mx.array]] = None, - logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, - prompt_cache: Optional[Any] = None, - prefill_step_size: int = 512, - kv_bits: Optional[int] = None, - kv_group_size: int = 64, - quantized_kv_start: int = 0, -) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: - """ - A generator producing token ids based on the given prompt from the model. - - Args: - prompt (mx.array): The input prompt. - model (nn.Module): The model to use for generation. - draft_model (nn.Module): The draft model for speculative decoding. - num_draft_tokens (int, optional): The number of draft tokens for - speculative decoding. Default: ``2``. - max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite - generator. Default: ``256``. - sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a - token from a vector of log probabilities. Default: ``None``. - logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. - prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if - provided, the cache will be updated in place. The cache must be trimmable. - prefill_step_size (int): Step size for processing the prompt. - kv_bits (int, optional): Number of bits to use for KV cache quantization. 
- None implies no cache quantization. Default: ``None``. - kv_group_size (int): Group size for KV cache quantization. Default: ``64``. - quantized_kv_start (int): Step to begin using a quantized KV cache. - when ``kv_bits`` is non-None. Default: ``0``. - - Yields: - Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities, - and a bool indicating if the token was generated by the draft model - """ - - y = prompt.astype(mx.uint32) - prev_tokens = None - - # Create the KV cache for generation - if prompt_cache is None: - model_cache = cache.make_prompt_cache(model) - draft_cache = cache.make_prompt_cache(draft_model) - elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)): - raise ValueError("Wrong number of layers in the prompt cache.") - else: - model_cache = prompt_cache[: len(model.layers)] - draft_cache = prompt_cache[len(model.layers) :] - - sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) - - quantize_cache_fn = functools.partial( - maybe_quantize_kv_cache, - quantized_kv_start=quantized_kv_start, - kv_group_size=kv_group_size, - kv_bits=kv_bits, - ) - - def _process_and_sample(tokens, logits): - if logits_processors: - for processor in logits_processors: - logits = processor(tokens, logits) - - logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - y = sampler(logprobs) - return y, logprobs - - def _step(model, cache, y, n_predict=1): - with mx.stream(generation_stream): - logits = model(y[None], cache=cache) - logits = logits[:, -n_predict:, :] - - quantize_cache_fn(cache) - if logits_processors: - nonlocal prev_tokens - out_y, out_logprobs = [], [] - if n_predict > 1: - y = y[: -(n_predict - 1)] - for i in range(n_predict): - prev_tokens = ( - mx.concat([prev_tokens, y]) if prev_tokens is not None else y - ) - y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :]) - out_y.append(y) - out_logprobs.append(logprobs) - return mx.concatenate(out_y, axis=0), mx.concatenate( - out_logprobs, axis=0 - ) - else: - return _process_and_sample(None, logits.squeeze(0)) - - def _prefill(model, cache, y): - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=cache) - quantize_cache_fn(cache) - mx.eval([c.state for c in cache]) - y = y[prefill_step_size:] - mx.metal.clear_cache() - return y - - def _rewind_cache(num_draft, num_accept): - cache.trim_prompt_cache(model_cache, num_draft - num_accept) - cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0)) - - def _draft_generate(y, num_draft): - if num_draft == 0: - return mx.array([], mx.uint32) - ys = [] - for _ in range(num_draft): - y, _ = _step(draft_model, draft_cache, y) - mx.async_eval(y) - ys.append(y) - return mx.concatenate(ys) - - with mx.stream(generation_stream): - draft_y = _prefill(draft_model, draft_cache, y) - y = _prefill(model, model_cache, y) - - ntoks = 0 - # Set these so the finally block doesn't raise - num_draft = 0 - n = 0 - try: - while True: - num_draft = min(max_tokens - ntoks, num_draft_tokens) - draft_tokens = _draft_generate(draft_y, num_draft) - if prev_tokens is not None: - prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1] - y = mx.concatenate([y, draft_tokens]) - tokens, logprobs = _step(model, model_cache, y, num_draft + 1) - mx.eval(tokens, draft_tokens) - draft_tokens = draft_tokens.tolist() - tokens = tokens.tolist() - n = 0 - while n < num_draft: - tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n] - if tn != dtn: - break - n += 1 - ntoks += 1 - yield tn, lpn, True - if ntoks == 
max_tokens: - break - if ntoks < max_tokens: - ntoks += 1 - yield tokens[n], logprobs[n], False - - if ntoks == max_tokens: - break - - y = mx.array([tokens[n]], mx.uint32) - draft_y = y - - # If we accepted all the draft tokens, include the last - # draft token in the next draft step since it hasn't been - # processed yet by the draft model - if n == num_draft: - draft_y = mx.concatenate( - [mx.array(draft_tokens[-1:], mx.uint32), draft_y] - ) - - if prev_tokens is not None: - prev_tokens = prev_tokens[: -max(num_draft - n, 1)] - _rewind_cache(num_draft, n) - finally: - _rewind_cache(num_draft, n) - - -def stream_generate( - model: nn.Module, - tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: Union[str, mx.array, List[int]], - draft_model: Optional[nn.Module] = None, - **kwargs, -) -> Generator[GenerationResponse, None, None]: - """ - A generator producing text based on the given prompt from the model. - - Args: - model (nn.Module): The model to use for generation. - tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, mx.array, List[int]]): The input prompt string or - integer tokens. - draft_model (Optional[nn.Module]): An optional draft model. If provided - then speculative decoding is used. The draft model must use the same - tokenizer as the main model. Default: ``None``. - kwargs: The remaining options get passed to :func:`generate_step`. - See :func:`generate_step` for more details. - - Yields: - GenerationResponse: An instance containing the generated text segment and - associated metadata. See :class:`GenerationResponse` for details. - """ - if not isinstance(tokenizer, TokenizerWrapper): - tokenizer = TokenizerWrapper(tokenizer) - - if not isinstance(prompt, mx.array): - if isinstance(prompt, str): - # Try to infer if special tokens are needed - add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( - tokenizer.bos_token - ) - prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) - prompt = mx.array(prompt) - - detokenizer = tokenizer.detokenizer - - if draft_model is None: - kwargs.pop("num_draft_tokens", None) - token_generator = generate_step(prompt, model, **kwargs) - # from_draft always false for non-speculative generation - token_generator = ( - (token, logprobs, False) for token, logprobs in token_generator - ) - else: - kwargs.pop("max_kv_size", None) - token_generator = speculative_generate_step( - prompt, model, draft_model, **kwargs - ) - with wired_limit(model, [generation_stream]): - detokenizer.reset() - tic = time.perf_counter() - for n, (token, logprobs, from_draft) in enumerate(token_generator): - if n == 0: - prompt_time = time.perf_counter() - tic - prompt_tps = prompt.size / prompt_time - tic = time.perf_counter() - if token in tokenizer.eos_token_ids: - break - - detokenizer.add_token(token) - - yield GenerationResponse( - text=detokenizer.last_segment, - token=token, - logprobs=logprobs, - from_draft=from_draft, - prompt_tokens=prompt.size, - prompt_tps=prompt_tps, - generation_tokens=n + 1, - generation_tps=(n + 1) / (time.perf_counter() - tic), - peak_memory=mx.metal.get_peak_memory() / 1e9, - finish_reason=None, - ) - - detokenizer.finalize() - yield GenerationResponse( - text=detokenizer.last_segment, - token=token, - logprobs=logprobs, - from_draft=from_draft, - prompt_tokens=prompt.size, - prompt_tps=prompt_tps, - generation_tokens=n + 1, - generation_tps=(n + 1) / (time.perf_counter() - tic), - peak_memory=mx.metal.get_peak_memory() / 1e9, - finish_reason="stop" if token in 
tokenizer.eos_token_ids else "length", - ) - - -def generate( - model: nn.Module, - tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: Union[str, List[int]], - verbose: bool = False, - formatter: Optional[Callable] = None, - **kwargs, -) -> str: - """ - Generate a complete response from the model. - - Args: - model (nn.Module): The language model. - tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, List[int]]): The input prompt string or integer tokens. - verbose (bool): If ``True``, print tokens and timing information. - Default: ``False``. - kwargs: The remaining options get passed to :func:`stream_generate`. - See :func:`stream_generate` for more details. - """ - if formatter is not None: - print( - "[Warning] Text formatting is deprecated and no longer used. " - "The argument will be removed in a future version." - ) - if verbose: - print("=" * 10) - - text = "" - for response in stream_generate(model, tokenizer, prompt, **kwargs): - if verbose: - print(response.text, end="", flush=True) - text += response.text - - if verbose: - print() - print("=" * 10) - if len(text) == 0: - print("No text generated for this prompt") - return - print( - f"Prompt: {response.prompt_tokens} tokens, " - f"{response.prompt_tps:.3f} tokens-per-sec" - ) - print( - f"Generation: {response.generation_tokens} tokens, " - f"{response.generation_tps:.3f} tokens-per-sec" - ) - print(f"Peak memory: {response.peak_memory:.3f} GB") - return text - - -def load_config(model_path: Path) -> dict: - try: - with open(model_path / "config.json", "r") as f: - config = json.load(f) - except FileNotFoundError: - logging.error(f"Config file not found in {model_path}") - raise - return config - - -def load_model( - model_path: Path, - lazy: bool = False, - strict: bool = True, - model_config: dict = {}, - get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, -) -> nn.Module: - """ - Load and initialize the model from a given path. - - Args: - model_path (Path): The path to load the model from. - lazy (bool): If False eval the model parameters to make sure they are - loaded in memory before returning, otherwise they will be loaded - when needed. Default: ``False`` - strict (bool): Whether or not to raise an exception if weights don't - match. Default: ``True`` - model_config (dict, optional): Optional configuration parameters for the - model. Defaults to an empty dictionary. - get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): - A function that returns the model class and model args class given a config. - Defaults to the ``_get_classes`` function. - - Returns: - nn.Module: The loaded and initialized model. - - Raises: - FileNotFoundError: If the weight files (.safetensors) are not found. - ValueError: If the model class or args class are not found or cannot be instantiated. 
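    Example (illustrative; the local path is a placeholder):

        >>> model, config = load_model(Path("mlx_model"))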
- """ - config = load_config(model_path) - config.update(model_config) - - weight_files = glob.glob(str(model_path / "model*.safetensors")) - - if not weight_files: - # Try weight for back-compat - weight_files = glob.glob(str(model_path / "weight*.safetensors")) - - if not weight_files and strict: - logging.error(f"No safetensors found in {model_path}") - raise FileNotFoundError(f"No safetensors found in {model_path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - model_class, model_args_class = get_model_classes(config=config) - - model_args = model_args_class.from_dict(config) - model = model_class(model_args) - - if hasattr(model, "sanitize"): - weights = model.sanitize(weights) - - if (quantization := config.get("quantization", None)) is not None: - - def class_predicate(p, m): - # Handle custom per layer quantizations - if p in config["quantization"]: - return config["quantization"][p] - if not hasattr(m, "to_quantized"): - return False - # Handle legacy models which may not have everything quantized - return f"{p}.scales" in weights - - nn.quantize( - model, - group_size=quantization["group_size"], - bits=quantization["bits"], - class_predicate=class_predicate, - ) - - model.load_weights(list(weights.items()), strict=strict) - - if not lazy: - mx.eval(model.parameters()) - - model.eval() - return model, config - - -def load( - path_or_hf_repo: str, - tokenizer_config={}, - model_config={}, - adapter_path: Optional[str] = None, - lazy: bool = False, -) -> Tuple[nn.Module, TokenizerWrapper]: - """ - Load the model and tokenizer from a given path or a huggingface repository. - - Args: - path_or_hf_repo (Path): The path or the huggingface repository to load the model from. - tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. - Defaults to an empty dictionary. - model_config(dict, optional): Configuration parameters specifically for the model. - Defaults to an empty dictionary. - adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers - to the model. Default: ``None``. - lazy (bool): If ``False`` eval the model parameters to make sure they are - loaded in memory before returning, otherwise they will be loaded - when needed. Default: ``False`` - Returns: - Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer. - - Raises: - FileNotFoundError: If config file or safetensors are not found. - ValueError: If model class or args class are not found. - """ - model_path = get_model_path(path_or_hf_repo) - - model, config = load_model(model_path, lazy) - if adapter_path is not None: - model = load_adapters(model, adapter_path) - model.eval() - tokenizer = load_tokenizer( - model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None) - ) - - return model, tokenizer - - -def fetch_from_hub( - model_path: Path, lazy: bool = False -) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: - model, config = load_model(model_path, lazy) - tokenizer = load_tokenizer( - model_path, eos_token_ids=config.get("eos_token_id", None) - ) - return model, config, tokenizer - - -def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: - """ - Splits the weights into smaller shards. - - Args: - weights (dict): Model weights. - max_file_size_gb (int): Maximum size of each shard in gigabytes. - - Returns: - list: List of weight shards. 
- """ - max_file_size_bytes = max_file_size_gb << 30 - shards = [] - shard, shard_size = {}, 0 - for k, v in weights.items(): - if shard_size + v.nbytes > max_file_size_bytes: - shards.append(shard) - shard, shard_size = {}, 0 - shard[k] = v - shard_size += v.nbytes - shards.append(shard) - return shards - - -def upload_to_hub(path: str, upload_repo: str, hf_path: str): - """ - Uploads the model to Hugging Face hub. - - Args: - path (str): Local path to the model. - upload_repo (str): Name of the HF repo to upload to. - hf_path (str): Path to the original Hugging Face model. - """ - import os - - from huggingface_hub import HfApi, ModelCard, logging - - from . import __version__ - - card = ModelCard.load(hf_path) - card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"] - card.data.base_model = hf_path - card.text = dedent( - f""" - # {upload_repo} - - The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was - converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) - using mlx-lm version **{__version__}**. - - ## Use with mlx - - ```bash - pip install mlx-lm - ``` - - ```python - from mlx_lm import load, generate - - model, tokenizer = load("{upload_repo}") - - prompt = "hello" - - if tokenizer.chat_template is not None: - messages = [{{"role": "user", "content": prompt}}] - prompt = tokenizer.apply_chat_template( - messages, add_generation_prompt=True - ) - - response = generate(model, tokenizer, prompt=prompt, verbose=True) - ``` - """ - ) - card.save(os.path.join(path, "README.md")) - - logging.set_verbosity_info() - - api = HfApi() - api.create_repo(repo_id=upload_repo, exist_ok=True) - api.upload_large_folder( - folder_path=path, - repo_id=upload_repo, - repo_type="model", - ) - print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") - - -def save_weights( - save_path: Union[str, Path], - weights: Dict[str, Any], - *, - donate_weights: bool = False, -) -> None: - """Save model weights into specified directory.""" - if isinstance(save_path, str): - save_path = Path(save_path) - save_path.mkdir(parents=True, exist_ok=True) - - shards = make_shards(weights) - shards_count = len(shards) - shard_file_format = ( - "model-{:05d}-of-{:05d}.safetensors" - if shards_count > 1 - else "model.safetensors" - ) - - total_size = sum(v.nbytes for v in weights.values()) - index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} - - # Write the weights and make sure no references are kept other than the - # necessary ones - if donate_weights: - weights.clear() - del weights - - for i in range(len(shards)): - shard = shards[i] - shards[i] = None - shard_name = shard_file_format.format(i + 1, shards_count) - shard_path = save_path / shard_name - - mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"}) - - for weight_name in shard.keys(): - index_data["weight_map"][weight_name] = shard_name - del shard - - index_data["weight_map"] = { - k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) - } - - with open(save_path / "model.safetensors.index.json", "w") as f: - json.dump( - index_data, - f, - indent=4, - ) - - -def quantize_model( - model: nn.Module, - config: dict, - q_group_size: int, - q_bits: int, - quant_predicate: Optional[ - Callable[[str, nn.Module, dict], Union[bool, dict]] - ] = None, -) -> Tuple: - """ - Applies quantization to the model weights. - - Args: - model (nn.Module): The model to be quantized. - config (dict): Model configuration. 
- q_group_size (int): Group size for quantization. - q_bits (int): Bits per weight for quantization. - quant_predicate (Callable): A callable that decides how - to quantize each layer based on the path. - Accepts the layer `path`, the `module` and the model `config`. - Returns either a bool to signify quantize/no quantize or - a dict of quantization parameters to pass to `to_quantized`. - - Returns: - Tuple: Tuple containing quantized weights and config. - """ - quantized_config = copy.deepcopy(config) - quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} - - # Add any custom quantization parameters to the config as we go - def _class_predicate(p, m): - bool_or_params = quant_predicate(p, m, config) - quantized_config["quantization"][p] = bool_or_params - return bool_or_params - - nn.quantize( - model, - q_group_size, - q_bits, - class_predicate=_class_predicate if quant_predicate else None, - ) - # support hf model tree #957 - quantized_config["quantization_config"] = quantized_config["quantization"] - quantized_weights = dict(tree_flatten(model.parameters())) - - bpw = compute_bits_per_weight(model) - print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.") - - return quantized_weights, quantized_config - - -def save_config( - config: dict, - config_path: Union[str, Path], -) -> None: - """Save the model configuration to the ``config_path``. - - The final configuration will be sorted before saving for better readability. - - Args: - config (dict): The model configuration. - config_path (Union[str, Path]): Model configuration file path. - """ - # Clean unused keys - config.pop("_name_or_path", None) - - # sort the config for better readability - config = dict(sorted(config.items())) - - # write the updated config to the config_path (if provided) - with open(config_path, "w") as fid: - json.dump(config, fid, indent=4) - - -def mixed_quant_predicate_builder( - low_bits: int = 4, high_bits: int = 4, group_size: int = 64 -) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: - def mixed_quant_predicate( - path: str, - module: nn.Module, - config: dict, - ) -> Union[bool, dict]: - """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M. 
- Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265 - By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3 - """ - - if not hasattr(module, "to_quantized"): - return False - - index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0 - - num_layers = config["num_hidden_layers"] - use_more_bits = ( - index < num_layers // 8 - or index >= 7 * num_layers // 8 - or (index - num_layers // 8) % 3 == 2 - ) - if "v_proj" in path and use_more_bits: - return {"group_size": group_size, "bits": high_bits} - if "down_proj" in path and use_more_bits: - return {"group_size": group_size, "bits": high_bits} - if "lm_head" in path: - return {"group_size": group_size, "bits": high_bits} - - return {"group_size": group_size, "bits": low_bits} - - return mixed_quant_predicate - - -mixed_3_6 = mixed_quant_predicate_builder(low_bits=3) -mixed_2_6 = mixed_quant_predicate_builder(low_bits=2) - - -def convert( - hf_path: str, - mlx_path: str = "mlx_model", - quantize: bool = False, - q_group_size: int = 64, - q_bits: int = 4, - dtype: str = "float16", - upload_repo: str = None, - revision: Optional[str] = None, - dequantize: bool = False, - quant_predicate: Optional[ - Callable[[str, nn.Module, dict], Union[bool, dict]] - ] = None, -): - # Check the save path is empty - if isinstance(mlx_path, str): - mlx_path = Path(mlx_path) - - if mlx_path.exists(): - raise ValueError( - f"Cannot save to the path {mlx_path} as it already exists." - " Please delete the file/directory or specify a new path to save to." - ) - - print("[INFO] Loading") - model_path = get_model_path(hf_path, revision=revision) - model, config, tokenizer = fetch_from_hub(model_path, lazy=True) - - weights = dict(tree_flatten(model.parameters())) - dtype = getattr(mx, dtype) - weights = {k: v.astype(dtype) for k, v in weights.items()} - - if quantize and dequantize: - raise ValueError("Choose either quantize or dequantize, not both.") - - if quantize: - print("[INFO] Quantizing") - model.load_weights(list(weights.items())) - weights, config = quantize_model( - model, config, q_group_size, q_bits, quant_predicate=quant_predicate - ) - - if dequantize: - print("[INFO] Dequantizing") - model = dequantize_model(model) - weights = dict(tree_flatten(model.parameters())) - - del model - save_weights(mlx_path, weights, donate_weights=True) - - py_files = glob.glob(str(model_path / "*.py")) - for file in py_files: - shutil.copy(file, mlx_path) - - tokenizer.save_pretrained(mlx_path) - - save_config(config, config_path=mlx_path / "config.json") - - if upload_repo is not None: - upload_to_hub(mlx_path, upload_repo, hf_path) diff --git a/llms/setup.py b/llms/setup.py deleted file mode 100644 index e6fddbae..00000000 --- a/llms/setup.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright © 2024 Apple Inc. 
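# Illustrative sketch (an assumption, not code from the files above): converting a
# model with the mixed-quantization predicate defined in mlx_lm/utils.py. The repo id
# and output path are placeholders.
#
#   from mlx_lm.utils import convert, mixed_3_6
#
#   convert(
#       "mistralai/Mistral-7B-Instruct-v0.3",
#       mlx_path="mlx_model_mixed",
#       quantize=True,
#       quant_predicate=mixed_3_6,
#   )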
- -import sys -from pathlib import Path - -from setuptools import setup - -package_dir = Path(__file__).parent / "mlx_lm" -with open(package_dir / "requirements.txt") as fid: - requirements = [l.strip() for l in fid.readlines()] - -sys.path.append(str(package_dir)) -from _version import __version__ - -setup( - name="mlx-lm", - version=__version__, - description="LLMs on Apple silicon with MLX and the Hugging Face Hub", - long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown", - readme="README.md", - author_email="mlx@group.apple.com", - author="MLX Contributors", - url="https://github.com/ml-explore/mlx-examples", - license="MIT", - install_requires=requirements, - packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], - python_requires=">=3.8", - extras_require={ - "test": ["datasets"], - "evaluate": ["lm-eval", "tqdm"], - }, - entry_points={ - "console_scripts": [ - "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", - "mlx_lm.chat = mlx_lm.chat:main", - "mlx_lm.convert = mlx_lm.convert:main", - "mlx_lm.evaluate = mlx_lm.evaluate:main", - "mlx_lm.fuse = mlx_lm.fuse:main", - "mlx_lm.generate = mlx_lm.generate:main", - "mlx_lm.lora = mlx_lm.lora:main", - "mlx_lm.merge = mlx_lm.merge:main", - "mlx_lm.server = mlx_lm.server:main", - "mlx_lm.manage = mlx_lm.manage:main", - ] - }, -) diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py deleted file mode 100644 index 5edab8bf..00000000 --- a/llms/tests/test_datsets.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import json -import os -import tempfile -import types -import unittest - -from mlx_lm.tuner import datasets -from transformers import AutoTokenizer - -HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" - - -class TestDatasets(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.test_dir_fid = tempfile.TemporaryDirectory() - cls.test_dir = cls.test_dir_fid.name - if not os.path.isdir(cls.test_dir): - os.mkdir(cls.test_dir_fid.name) - - @classmethod - def tearDownClass(cls): - cls.test_dir_fid.cleanup() - - def save_data(self, data): - for ds in ["train", "valid"]: - with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid: - for l in data: - json.dump(l, fid) - fid.write("\n") - - def test_text(self): - data = {"text": "This is an example for the model."} - self.save_data(4 * [data]) - args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) - tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) - train, valid, test = datasets.load_dataset(args, tokenizer) - self.assertEqual(len(train), 4) - self.assertEqual(len(valid), 4) - self.assertEqual(len(test), 0) - self.assertTrue(len(train[0]) > 0) - self.assertTrue(len(valid[0]) > 0) - self.assertTrue(isinstance(train, datasets.Dataset)) - - def test_completions(self): - data = {"prompt": "What is the capital of France?", "completion": "Paris."} - self.save_data(4 * [data]) - args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) - tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) - train, valid, test = datasets.load_dataset(args, tokenizer) - self.assertEqual(len(train), 4) - self.assertEqual(len(valid), 4) - self.assertEqual(len(test), 0) - self.assertTrue(len(train[0]) > 0) - self.assertTrue(len(valid[0]) > 0) - self.assertTrue(isinstance(train, datasets.CompletionsDataset)) - - def test_chat(self): - data = { - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello."}, - 
{"role": "assistant", "content": "How can I assistant you today."}, - ] - } - self.save_data(4 * [data]) - args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) - tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) - train, valid, test = datasets.load_dataset(args, tokenizer) - self.assertEqual(len(train), 4) - self.assertEqual(len(valid), 4) - self.assertEqual(len(test), 0) - self.assertTrue(len(train[0]) > 0) - self.assertTrue(len(valid[0]) > 0) - self.assertTrue(isinstance(train, datasets.ChatDataset)) - - def test_hf(self): - hf_args = { - "name": "billsum", - "prompt_feature": "text", - "completion_feature": "summary", - "train_split": "train[:2%]", - "valid_split": "train[-2%:]", - } - args = types.SimpleNamespace( - hf_dataset=hf_args, - test=False, - train=True, - ) - tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) - train, valid, test = datasets.load_dataset(args, tokenizer) - self.assertTrue(len(train) > 0) - self.assertTrue(len(train[0]) > 0) - self.assertTrue(len(valid) > 0) - self.assertTrue(len(valid[0]) > 0) - self.assertEqual(len(test), 0) - - args = types.SimpleNamespace( - hf_dataset=[hf_args, hf_args], - test=False, - train=True, - ) - train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) - self.assertEqual(2 * len(train), len(train_double)) - self.assertEqual(2 * len(valid), len(valid_double)) - self.assertEqual(2 * len(test), len(test_double)) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py deleted file mode 100644 index a6d53747..00000000 --- a/llms/tests/test_finetune.py +++ /dev/null @@ -1,447 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import math -import sys -import unittest -from contextlib import contextmanager -from io import StringIO -from unittest.mock import MagicMock - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as opt -from mlx.utils import tree_flatten -from mlx_lm import lora, tuner -from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear -from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear -from mlx_lm.tuner.trainer import evaluate -from mlx_lm.tuner.utils import build_schedule - - -@contextmanager -def swapped_with_identity(obj, func): - old_func = getattr(obj, func) - setattr(obj, func, lambda x, **kwargs: x) - yield - setattr(obj, func, old_func) - - -class TestLora(unittest.TestCase): - def setUp(self): - self.capturedOutput = StringIO() - sys.stdout = self.capturedOutput - - def tearDown(self): - sys.stdout = sys.__stdout__ - - def test_llama(self): - from mlx_lm.models import llama - - args = llama.ModelArgs( - model_type="llama", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - tie_word_embeddings=False, - ) - - lora_layers = 4 - - def check_config(params, expected_trainable_parameters=None): - n_keys = 2 - if "keys" in params: - n_keys = len(params["keys"]) - model = llama.Model(args) - model.freeze() - tuner.utils.linear_to_lora_layers(model, lora_layers, params) - trainable_params = sum( - v.size for _, v in tree_flatten(model.trainable_parameters()) - ) - - expected_trainable_parameters = expected_trainable_parameters or ( - lora_layers * params["rank"] * args.hidden_size * 2 * n_keys - ) - self.assertEqual(trainable_params, expected_trainable_parameters) - - params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0} - check_config(params) - - params["rank"] = 1 - check_config(params) - - 
params["keys"] = ["self_attn.k_proj"] - check_config(params) - - params["keys"] = ["lm_head"] - check_config( - params, - expected_trainable_parameters=( - params["rank"] * (args.hidden_size + args.vocab_size) - ), - ) - - params["keys"] = ["model.embed_tokens"] - check_config( - params, - expected_trainable_parameters=( - params["rank"] * (args.hidden_size + args.vocab_size) - ), - ) - - def test_gpt_neox(self): - from mlx_lm.models import gpt_neox - - args = gpt_neox.ModelArgs( - model_type="gpt_neox", - max_position_embeddings=2048, - hidden_size=6144, - num_attention_heads=64, - num_hidden_layers=44, - layer_norm_eps=1e-5, - vocab_size=50432, - rotary_emb_base=10_000, - rotary_pct=0.25, - ) - - num_lora_layers = 4 - params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0} - - model = gpt_neox.Model(args) - model.freeze() - tuner.utils.linear_to_lora_layers(model, num_lora_layers, params) - - def test_lora_embedding(self): - num_embeddings = 256 - dims = 512 - tokens = mx.array([1, 2, 3]) - - embedding = nn.QuantizedEmbedding(num_embeddings, dims) - dequantized_weight = mx.dequantize( - embedding.weight, - embedding.scales, - embedding.biases, - embedding.group_size, - embedding.bits, - ) - lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10) - new_embedding = lora_emb.fuse(de_quantize=True) - self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight)) - self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens))) - - # as_linear - attn_output = mx.random.uniform(shape=(dims,)) - embedding_lin_out = lora_emb.as_linear(attn_output) - self.assertEqual(embedding_lin_out.shape, (num_embeddings,)) - self.assertTrue( - mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output)) - ) - - # change the value of lora_b and the embeddings will no longer be equal - lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape) - new_embedding = lora_emb.fuse(de_quantize=True) - self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight)) - self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens))) - - -class TestDora(unittest.TestCase): - def test_dora_embedding(self): - num_embeddings = 256 - dims = 512 - tokens = mx.array([1, 2, 3]) - - embedding = nn.Embedding(num_embeddings, dims) - - dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10) - new_embedding = dora_emb.fuse() - self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight)) - self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens))) - - # as_linear - attn_output = mx.random.uniform(shape=(dims,)) - embedding_lin_out = dora_emb.as_linear(attn_output) - self.assertEqual(embedding_lin_out.shape, (num_embeddings,)) - self.assertTrue( - mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output)) - ) - - # change the value of lora_b and the embeddings will no longer be equal - dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape) - new_embedding = dora_emb.fuse() - self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight)) - self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens))) - - def test_llama(self): - from mlx_lm.models import llama - - hidden_size = 1024 - intermediate_size = 2048 - args = llama.ModelArgs( - model_type="llama", - hidden_size=hidden_size, - num_hidden_layers=4, - intermediate_size=intermediate_size, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - ) - - dora_layers = 4 - - def check_config(params): - n_keys = 2 - if "keys" in 
params: - n_keys = len(params["keys"]) - model = llama.Model(args) - model.freeze() - tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True) - trainable_params = sum( - v.size for _, v in tree_flatten(model.trainable_parameters()) - ) - self.assertEqual( - trainable_params, - dora_layers - * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size), - ) - - params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0} - check_config(params) - - params["rank"] = 1 - check_config(params) - - params["keys"] = ["self_attn.k_proj"] - check_config(params) - - def test_dora_m_parameter(self): - dora_lin = DoRALinear(input_dims=100, output_dims=100) - self.assertTrue( - mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1)) - ) - - # Recomputes m when changing Linear - inital_m = dora_lin.m - lin = nn.Linear(10, 10) - dora_lin.set_linear(lin) - self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1))) - - # Works with quantized weights - quantized_linear = nn.QuantizedLinear(512, 512) - dora_lin.set_linear(quantized_linear) - dequantized_weight = mx.dequantize( - quantized_linear.weight, - quantized_linear.scales, - quantized_linear.biases, - quantized_linear.group_size, - quantized_linear.bits, - ) - self.assertTrue( - mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1)) - ) - - def test_dora_from_linear(self): - in_dims = 256 - out_dims = 256 - r = 4 - - linear = nn.Linear(in_dims, out_dims) - dora_lin = DoRALinear.from_base(linear, r) - self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1))) - self.assertEqual(dora_lin.lora_a.shape, (in_dims, r)) - self.assertEqual(dora_lin.lora_b.shape, (r, out_dims)) - self.assertEqual(dora_lin.m.shape, (out_dims,)) - - quantized_linear = nn.QuantizedLinear(in_dims, out_dims) - dequantized_weight = mx.dequantize( - quantized_linear.weight, - quantized_linear.scales, - quantized_linear.biases, - quantized_linear.group_size, - quantized_linear.bits, - ) - dora_quant_lin = DoRALinear.from_base(quantized_linear, r) - self.assertTrue( - mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1)) - ) - self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r)) - self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims)) - self.assertEqual(dora_quant_lin.m.shape, (out_dims,)) - - def test_dora_to_linear(self): - in_dims = 256 - out_dims = 256 - r = 4 - - linear = nn.Linear(in_dims, out_dims, bias=True) - dora_lin = DoRALinear.from_base(linear, r) - to_linear = dora_lin.fuse() - self.assertTrue(mx.allclose(linear.weight, to_linear.weight)) - self.assertTrue(mx.allclose(linear.bias, to_linear.bias)) - - def dequantize_weight(quantized_linear): - return mx.dequantize( - quantized_linear.weight, - quantized_linear.scales, - quantized_linear.biases, - quantized_linear.group_size, - quantized_linear.bits, - ) - - quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True) - dora_quantized_linear = DoRALinear.from_base(quantized_linear, r) - # Dequantize - to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True) - self.assertTrue( - mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias) - ) - self.assertTrue( - mx.allclose( - dequantize_weight(quantized_linear), to_linear_from_quantized.weight - ) - ) - - def test_dora_dtype(self): - in_dims = 256 - out_dims = 256 - r = 4 - - linear = nn.Linear(in_dims, out_dims, bias=True) - linear.set_dtype(mx.float16) - dora_lin = DoRALinear.from_base(linear, r) - - x = 
mx.random.uniform(shape=(2, 256)).astype(mx.float16) - self.assertEqual(dora_lin(x).dtype, mx.float16) - - -class TestScheduleConfig(unittest.TestCase): - def test_join(self): - config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]} - cos_with_warmup = build_schedule(config) - self.assertIsNotNone(cos_with_warmup) - - self.assertEqual(cos_with_warmup(0), 0.0) - self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1) - optimizer = opt.Adam(learning_rate=cos_with_warmup) - for _ in range(100): - optimizer.update({}, {}) - self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1) - for _ in range(100): - optimizer.update({}, {}) - expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10)) - self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1) - - def test_single_schedule(self): - - config = { - "name": "cosine_decay", - "arguments": [0.1, 10], - } - lr_schedule = build_schedule(config) - lr = lr_schedule(4) - expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10)) - self.assertAlmostEqual(lr, expected_lr, delta=1e-7) - - def test_non_zero_warmup(self): - config = { - "name": "cosine_decay", - "warmup": 10, - "warmup_init": 1e-6, - "arguments": [1e-5, 20], - } - lr_schedule = build_schedule(config) - lr = lr_schedule(0) - self.assertAlmostEqual(lr, 1e-6, delta=1e-7) - - def test_malformed_config(self): - config = {"warmup": 100} - self.assertRaises(KeyError, build_schedule, config) - - config = {"cosine_decay": None} - self.assertRaises(KeyError, build_schedule, config) - - def test_evaluate_calls(self): - mock_model = MagicMock() - mock_dataset = MagicMock() - mock_tokenizer = MagicMock() - mock_default_loss = MagicMock() - mock_iterate_batches = MagicMock() - - mock_iterate_batches.return_value = [ - (MagicMock(), MagicMock()), - (MagicMock(), MagicMock()), - (MagicMock(), MagicMock()), - (MagicMock(), MagicMock()), - (MagicMock(), MagicMock()), - ] - - mock_default_loss.side_effect = [ - (MagicMock(return_value=0.5), MagicMock(return_value=100)), - (MagicMock(return_value=0.3), MagicMock(return_value=200)), - (MagicMock(return_value=0.2), MagicMock(return_value=150)), - (MagicMock(return_value=0.4), MagicMock(return_value=180)), - (MagicMock(return_value=0.6), MagicMock(return_value=120)), - ] - with swapped_with_identity(mx.distributed, "all_sum"): - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=2, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) - - mock_iterate_batches.assert_called_once_with( - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - max_seq_length=2048, - ) - self.assertEqual(mock_default_loss.call_count, 2) - - def test_evaluate_infinite_batches(self): - mock_model = MagicMock() - mock_dataset = MagicMock() - mock_tokenizer = MagicMock() - mock_default_loss = MagicMock() - mock_iterate_batches = MagicMock() - - mock_iterate_batches.return_value = [ - (MagicMock(), MagicMock()), - (MagicMock(), MagicMock()), - (MagicMock(), MagicMock()), - ] - - mock_default_loss.side_effect = [ - (MagicMock(return_value=0.5), MagicMock(return_value=100)), - (MagicMock(return_value=0.3), MagicMock(return_value=200)), - (MagicMock(return_value=0.2), MagicMock(return_value=150)), - ] - - with swapped_with_identity(mx.distributed, "all_sum"): - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=-1, - max_seq_length=2048, - 
loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) - - mock_iterate_batches.assert_called_once_with( - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - max_seq_length=2048, - ) - self.assertEqual(mock_default_loss.call_count, 3) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py deleted file mode 100644 index 7445a9b9..00000000 --- a/llms/tests/test_generate.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import unittest -from typing import List - -from mlx_lm.sample_utils import make_logits_processors -from mlx_lm.utils import ( - GenerationResponse, - generate, - load, - make_sampler, - stream_generate, -) - - -class TestGenerate(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" - cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH) - - def test_generate(self): - # Simple test that generation runs - text = generate( - self.model, self.tokenizer, "hello", max_tokens=5, verbose=False - ) - - def test_generate_with_logit_bias(self): - logit_bias = {0: 2000.0, 1: -20.0} - text = generate( - self.model, - self.tokenizer, - "hello", - max_tokens=5, - logits_processors=make_logits_processors(logit_bias), - verbose=False, - ) - self.assertEqual(text, "!!!!!") - - def test_generate_with_processor(self): - init_toks = self.tokenizer.encode("hello") - - all_toks = None - - def logits_processor(toks, logits): - nonlocal all_toks - all_toks = toks - return logits - - generate( - self.model, - self.tokenizer, - "hello", - max_tokens=5, - verbose=False, - logits_processors=[logits_processor], - ) - self.assertEqual(len(all_toks), len(init_toks) + 5) - - def test_stream_generate_speculative(self): - # Use same model as draft model, this is not a speed test - draft_model, _ = load(self.HF_MODEL_PATH) - - results: List[GenerationResponse] = [] - drafted: List[bool] = [] - - # make a determinate sampler - sampler = make_sampler(temp=0.0) - - for generation_result in stream_generate( - model=self.model, - tokenizer=self.tokenizer, - prompt="hello", - max_tokens=5, - draft_model=draft_model, - num_draft_tokens=2, - sampler=sampler, - ): - drafted.append(generation_result.from_draft) - results.append(generation_result) - - self.assertEqual(len(results), 5) - # since num_draft_tokens is 2 and draft model is the same, the - # first 2 generations should be drafts, the third should come - # from the target model, and last two should be drafts - self.assertEqual(drafted, [True, True, False, True, True]) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_gguf.py b/llms/tests/test_gguf.py deleted file mode 100644 index 24ca64aa..00000000 --- a/llms/tests/test_gguf.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import tempfile -import unittest -from pathlib import Path -from unittest.mock import MagicMock, patch - -import mlx.core as mx -from mlx_lm.gguf import convert_to_gguf - - -class TestConvertToGGUFWithoutMocks(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.test_dir_fid = tempfile.TemporaryDirectory() - cls.test_dir = cls.test_dir_fid.name - cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json") - with open(cls.tokenizer_file_path, "w") as f: - f.write("{}") - - @classmethod - def tearDownClass(cls): - cls.test_dir_fid.cleanup() - - @patch("transformers.AutoTokenizer.from_pretrained") - @patch("mlx.core.save_gguf") - def test_convert_to_gguf( - self, - 
mock_save_gguf, - mock_from_pretrained, - ): - mock_tokenizer = MagicMock() - mock_tokenizer.vocab_size = 3 - mock_tokenizer.get_added_vocab.return_value = {} - mock_tokenizer.get_vocab.return_value = {"": 0, "hello": 1, "world": 2} - mock_tokenizer.all_special_tokens = [""] - mock_tokenizer.all_special_ids = [0] - mock_from_pretrained.return_value = mock_tokenizer - - model_path = Path(self.test_dir) - weights = { - "self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]), - } - config = { - "num_attention_heads": 1, - "num_hidden_layers": 1, - "hidden_size": 768, - "intermediate_size": 3072, - "_name_or_path": "test-llama", - } - output_file_path = "/fake/output/path/gguf_model.gguf" - - convert_to_gguf(model_path, weights, config, output_file_path) - called_args, _ = mock_save_gguf.call_args - self.assertEqual(called_args[0], output_file_path) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py deleted file mode 100644 index b4e7aab8..00000000 --- a/llms/tests/test_models.py +++ /dev/null @@ -1,986 +0,0 @@ -# Copyright © 2024 Apple Inc. -import unittest - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_map -from mlx_lm.models import rope_utils -from mlx_lm.models.base import create_causal_mask -from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache - - -class TestModels(unittest.TestCase): - - def test_kv_cache(self): - cache = KVCache() - - k = mx.ones((1, 4, 1, 32), mx.float16) - v = mx.ones((1, 4, 1, 32), mx.float16) - - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up, k)) - self.assertTrue(mx.array_equal(v_up, v)) - self.assertEqual(cache.offset, 1) - - k = mx.ones((1, 4, cache.step, 32), mx.float16) - v = mx.ones((1, 4, cache.step, 32), mx.float16) - k_up, v_up = cache.update_and_fetch(k, v) - - expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16) - self.assertTrue(mx.array_equal(k_up, expected)) - self.assertTrue(mx.array_equal(v_up, expected)) - self.assertEqual(cache.offset, cache.step + 1) - - def test_rotating_kv_cache(self): - b, h, d = 1, 2, 32 - cache = RotatingKVCache(max_size=8, step=4) - - k = mx.random.uniform(shape=(b, h, 2, d)) - v = mx.random.uniform(shape=(b, h, 2, d)) - - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up, k)) - self.assertTrue(mx.array_equal(v_up, v)) - self.assertEqual(cache.offset, 2) - - k = mx.random.uniform(shape=(b, h, 5, d)) - v = mx.random.uniform(shape=(b, h, 5, d)) - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up[..., 2:, :], k)) - self.assertTrue(mx.array_equal(v_up[..., 2:, :], v)) - - k = mx.random.uniform(shape=(b, h, 4, d)) - v = mx.random.uniform(shape=(b, h, 4, d)) - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up[..., -4:, :], k)) - self.assertTrue(mx.array_equal(v_up[..., -4:, :], v)) - - idx = 0 - for _ in range(10): - k = mx.random.uniform(shape=(b, h, 1, d)) - v = mx.random.uniform(shape=(b, h, 1, d)) - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k)) - self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v)) - idx += 1 - idx %= 8 - - # Try with nonzero keep - cache = RotatingKVCache(max_size=8, step=4, keep=2) - - # Check a large update - k = mx.random.uniform(shape=(b, h, 20, d)) - v = mx.random.uniform(shape=(b, h, 20, d)) - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up, k)) - 
self.assertTrue(mx.array_equal(v_up, v)) - - # A bunch of small updates - self.assertEqual(cache.offset, 20) - idx = 2 - for i in range(10): - k = mx.random.uniform(shape=(b, h, 1, d)) - v = mx.random.uniform(shape=(b, h, 1, d)) - k_up, v_up = cache.update_and_fetch(k, v) - self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k)) - self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v)) - self.assertEqual(cache.offset, 21 + i) - idx += 1 - if idx >= 8: - idx = 2 - - def test_rotating_kv_cache_chat_mode(self): - # Test that the rotating kv cache can handle - # alternating prompt/prefill with generation - d = 4 - h = 2 - cache = RotatingKVCache(max_size=18, step=4) - - x = mx.random.uniform(shape=(1, h, 8, d)) - k, v = cache.update_and_fetch(x, x) - self.assertEqual(k.shape[2], 8) - self.assertEqual(cache.offset, 8) - - x = mx.random.uniform(shape=(1, h, 1, d)) - k, v = cache.update_and_fetch(x, x) - self.assertEqual(k.shape[2], 9) - self.assertEqual(cache.offset, 9) - self.assertTrue(mx.allclose(x, k[..., 8:9, :])) - - x = mx.random.uniform(shape=(1, h, 2, d)) - k, v = cache.update_and_fetch(x, x) - self.assertEqual(k.shape[2], 11) - self.assertEqual(cache.offset, 11) - self.assertTrue(mx.allclose(x, k[..., 9:11, :])) - - x = mx.random.uniform(shape=(1, h, 3, d)) - k, v = cache.update_and_fetch(x, x) - self.assertEqual(k.shape[2], 14) - self.assertEqual(cache.offset, 14) - self.assertTrue(mx.allclose(x, k[..., 11:14, :])) - - x = mx.random.uniform(shape=(1, h, 6, d)) - k, v = cache.update_and_fetch(x, x) - self.assertEqual(cache.offset, 20) - self.assertTrue(mx.allclose(x, k[..., -6:, :])) - - x = mx.random.uniform(shape=(1, h, 2, d)) - k, v = cache.update_and_fetch(x, x) - self.assertEqual(cache.offset, 22) - self.assertTrue(mx.allclose(x, k[..., -2:, :])) - - def test_causal_mask_lengths(self): - mx.random.seed(8) - B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2) - lengths = mx.array([1, 2, 3, 1]) - q = mx.random.uniform(shape=(B, N_q, T_q, D)) - k = mx.random.uniform(shape=(B, N_kv, T_kv, D)) - v = k - mask = create_causal_mask(T_q, 0, lengths=lengths) - - out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) - q[1, :, 2:] = mx.ones_like(q[1, :, 2:]) - k[1, :, 2:] = mx.ones_like(k[1, :, 2:]) - v[1, :, 2:] = mx.ones_like(v[1, :, 2:]) - out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) - self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2])) - - def test_rope(self): - rope = rope_utils.initialize_rope(32, base=100, traditional=False) - self.assertTrue(isinstance(rope, nn.RoPE)) - - rope = rope_utils.initialize_rope( - 32, - base=100, - traditional=False, - scaling_config={"rope_type": "linear", "factor": 10.0}, - ) - self.assertTrue(isinstance(rope, nn.RoPE)) - - rope = rope_utils.initialize_rope( - 32, - base=100, - traditional=False, - scaling_config={"rope_type": "llama3", "factor": 2.0}, - ) - self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE)) - - def model_test_runner(self, model, model_type, vocab_size, num_layers): - - self.assertEqual(len(model.layers), num_layers) - self.assertEqual(model.model_type, model_type) - - for t in [mx.float32, mx.float16]: - model.update(tree_map(lambda p: p.astype(t), model.parameters())) - - inputs = mx.array([[0, 1]]) - outputs = model(inputs) - self.assertEqual(outputs.shape, (1, 2, vocab_size)) - self.assertEqual(outputs.dtype, t) - - cache = make_prompt_cache(model) - outputs = model(inputs, cache=cache) - self.assertEqual(outputs.shape, (1, 2, vocab_size)) - 
self.assertEqual(outputs.dtype, t) - - if model_type not in ("mamba", "plamo2"): - mask = create_causal_mask(inputs.shape[1], 0).astype(t) - outputs = model(inputs, mask=mask) - self.assertEqual(outputs.shape, (1, 2, vocab_size)) - self.assertEqual(outputs.dtype, t) - - outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache) - self.assertEqual(outputs.shape, (1, 1, vocab_size)) - self.assertEqual(outputs.dtype, t) - - def test_llama(self): - from mlx_lm.models import llama - - args = llama.ModelArgs( - model_type="llama", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - ) - model = llama.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_phi2(self): - from mlx_lm.models import phi - - args = phi.ModelArgs() - model = phi.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_phixtral(self): - from mlx_lm.models import phixtral - - args = phixtral.ModelArgs( - "phixtral", num_vocab=1000, num_layers=4, model_dim=1024 - ) - model = phixtral.Model(args) - self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers) - - def test_phi3(self): - from mlx_lm.models import phi3 - - args = phi3.ModelArgs( - model_type="phi3", - hidden_size=3072, - num_hidden_layers=32, - intermediate_size=8192, - num_attention_heads=32, - rms_norm_eps=1e-5, - vocab_size=32064, - ) - model = phi3.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_gemma(self): - from mlx_lm.models import gemma - - args = gemma.ModelArgs( - model_type="gemma", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - head_dim=128, - rms_norm_eps=1e-5, - vocab_size=10_000, - num_key_value_heads=4, - ) - model = gemma.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_mixtral(self): - from mlx_lm.models import mixtral - - # Make a baby mixtral, because it will actually do the - # eval - args = mixtral.ModelArgs( - model_type="mixtral", - vocab_size=100, - hidden_size=32, - intermediate_size=128, - num_hidden_layers=2, - num_attention_heads=4, - num_experts_per_tok=2, - num_key_value_heads=2, - num_local_experts=4, - ) - model = mixtral.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - @unittest.skip("requires ai2-olmo") - def test_olmo(self): - from mlx_lm.models import olmo - - args = olmo.ModelArgs( - model_type="olmo", - d_model=1024, - n_layers=4, - mlp_hidden_size=2048, - n_heads=2, - vocab_size=10_000, - embedding_size=10_000, - ) - model = olmo.Model(args) - self.model_test_runner( - model, - args.model_type, - args.vocab_size, - args.n_layers, - ) - - def test_qwen2_moe(self): - from mlx_lm.models import qwen2_moe - - args = qwen2_moe.ModelArgs( - model_type="qwen2_moe", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - num_experts_per_tok=4, - num_experts=16, - moe_intermediate_size=1024, - shared_expert_intermediate_size=2048, - ) - model = qwen2_moe.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_qwen2(self): - from mlx_lm.models import qwen2 - - args = qwen2.ModelArgs( - 
model_type="qwen2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - ) - model = qwen2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_qwen(self): - from mlx_lm.models import qwen - - args = qwen.ModelArgs( - model_type="qwen", - ) - model = qwen.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_plamo(self): - from mlx_lm.models import plamo - - args = plamo.ModelArgs( - model_type="plamo", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=8, - rms_norm_eps=1e-5, - vocab_size=10_000, - ) - model = plamo.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_plamo2(self): - from mlx_lm.models import plamo2 - - args = plamo2.ModelArgs( - model_type="plamo2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=8, - rms_norm_eps=1e-5, - vocab_size=10_000, - ) - model = plamo2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_stablelm(self): - from mlx_lm.models import stablelm - - args = stablelm.ModelArgs( - model_type="stablelm", - vocab_size=10_000, - hidden_size=1024, - num_attention_heads=4, - num_hidden_layers=4, - num_key_value_heads=2, - partial_rotary_factor=1.0, - intermediate_size=2048, - layer_norm_eps=1e-2, - rope_theta=10_000, - use_qkv_bias=False, - ) - model = stablelm.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - # StableLM 2 - args = stablelm.ModelArgs( - model_type="stablelm", - vocab_size=10000, - hidden_size=512, - num_attention_heads=8, - num_hidden_layers=4, - num_key_value_heads=2, - partial_rotary_factor=0.25, - intermediate_size=1024, - layer_norm_eps=1e-5, - rope_theta=10000, - use_qkv_bias=True, - ) - model = stablelm.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_starcoder2(self): - from mlx_lm.models import starcoder2 - - args = starcoder2.ModelArgs( - model_type="starcoder2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - num_key_value_heads=4, - ) - model = starcoder2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_cohere(self): - from mlx_lm.models import cohere - - args = cohere.ModelArgs( - model_type="cohere", - ) - model = cohere.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_dbrx(self): - from mlx_lm.models import dbrx - - args = dbrx.ModelArgs( - model_type="dbrx", - d_model=1024, - ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2}, - attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000}, - n_layers=4, - n_heads=4, - vocab_size=10_000, - ) - model = dbrx.Model(args) - self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers) - - def test_minicpm(self): - from mlx_lm.models import minicpm - - args = minicpm.ModelArgs( - model_type="minicpm", - hidden_size=1024, - dim_model_base=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-4, - vocab_size=10000, - num_key_value_heads=2, - 
scale_depth=1.0, - scale_emb=1.0, - ) - model = minicpm.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_mamba(self): - from mlx_lm.models import mamba - - args = mamba.ModelArgs( - model_type="mamba", - vocab_size=10000, - use_bias=False, - use_conv_bias=True, - conv_kernel=4, - hidden_size=768, - num_hidden_layers=24, - state_size=16, - intermediate_size=1536, - time_step_rank=48, - ) - model = mamba.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_gpt2(self): - from mlx_lm.models import gpt2 - - args = gpt2.ModelArgs( - model_type="gpt2", - n_ctx=1024, - n_embd=768, - n_head=12, - n_layer=12, - n_positions=1024, - layer_norm_epsilon=1e-5, - vocab_size=50256, - ) - model = gpt2.Model(args) - self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer) - - def test_gpt_neox(self): - from mlx_lm.models import gpt_neox - - args = gpt_neox.ModelArgs( - model_type="gpt_neox", - max_position_embeddings=2048, - hidden_size=6144, - num_attention_heads=64, - num_hidden_layers=44, - layer_norm_eps=1e-5, - vocab_size=50432, - rotary_emb_base=10_000, - rotary_pct=0.25, - ) - model = gpt_neox.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_openelm(self): - from mlx_lm.models import openelm - - args = openelm.ModelArgs( - model_type="openelm", - ffn_dim_divisor=256, - ffn_multipliers=[ - 0.5, - 0.73, - 0.97, - 1.2, - 1.43, - 1.67, - 1.9, - 2.13, - 2.37, - 2.6, - 2.83, - 3.07, - 3.3, - 3.53, - 3.77, - 4.0, - ], - head_dim=64, - model_dim=1280, - normalize_qk_projections=True, - num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5], - num_query_heads=[ - 12, - 12, - 12, - 12, - 12, - 16, - 16, - 16, - 16, - 16, - 16, - 16, - 20, - 20, - 20, - 20, - ], - num_transformer_layers=16, - vocab_size=32000, - ) - - model = openelm.Model(args) - self.model_test_runner( - model, - args.model_type, - args.vocab_size, - len(args.ffn_multipliers), - ) - - def test_internlm2(self): - from mlx_lm.models import internlm2 - - args = internlm2.ModelArgs( - model_type="internlm2", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10000, - ) - model = internlm2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_llama3_1(self): - from mlx_lm.models import llama - - args = llama.ModelArgs( - model_type="llama", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - max_position_embeddings=128, - mlp_bias=False, - num_key_value_heads=2, - rope_scaling={ - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3", - }, - ) - model = llama.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_deepseek(self): - from mlx_lm.models import deepseek - - args = deepseek.ModelArgs( - model_type="deepseek", - vocab_size=1024, - hidden_size=128, - intermediate_size=256, - moe_intermediate_size=256, - num_hidden_layers=4, - num_attention_heads=8, - num_key_value_heads=4, - ) - model = deepseek.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_deepseek_v2(self): 
- from mlx_lm.models import deepseek_v2 - - args = deepseek_v2.ModelArgs( - model_type="deepseek_v2", - vocab_size=1024, - hidden_size=128, - intermediate_size=256, - moe_intermediate_size=256, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=2, - kv_lora_rank=4, - q_lora_rank=4, - qk_rope_head_dim=32, - v_head_dim=16, - qk_nope_head_dim=32, - rope_scaling={ - "beta_fast": 32, - "beta_slow": 1, - "factor": 40, - "mscale": 1.0, - "mscale_all_dim": 1.0, - "original_max_position_embeddings": 4096, - "type": "yarn", - }, - ) - model = deepseek_v2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_deepseek_v3(self): - from mlx_lm.models import deepseek_v3 - - args = deepseek_v3.ModelArgs( - model_type="deepseek_v3", - vocab_size=1024, - hidden_size=128, - intermediate_size=256, - moe_intermediate_size=256, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=2, - n_routed_experts=4, - n_group=2, - topk_group=1, - num_experts_per_tok=2, - n_shared_experts=1, - kv_lora_rank=4, - q_lora_rank=4, - qk_rope_head_dim=32, - v_head_dim=16, - qk_nope_head_dim=32, - rope_scaling={ - "beta_fast": 32, - "beta_slow": 1, - "factor": 40, - "mscale": 1.0, - "mscale_all_dim": 1.0, - "original_max_position_embeddings": 4096, - "type": "yarn", - }, - ) - model = deepseek_v3.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_gemma2(self): - from mlx_lm.models import gemma2 - - args = gemma2.ModelArgs( - model_type="gemma2", - hidden_size=128, - num_hidden_layers=4, - intermediate_size=256, - num_attention_heads=2, - head_dim=32, - rms_norm_eps=1e-4, - vocab_size=1024, - num_key_value_heads=2, - ) - model = gemma2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_gemma3_text(self): - from mlx_lm.models import gemma3_text - - args = gemma3_text.ModelArgs( - model_type="gemma3_text", - hidden_size=128, - num_hidden_layers=12, - intermediate_size=256, - num_attention_heads=4, - head_dim=32, - rms_norm_eps=1e-4, - num_key_value_heads=1, - sliding_window=1024, - sliding_window_pattern=6, - ) - model = gemma3_text.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_gpt_bigcode(self): - from mlx_lm.models import gpt_bigcode - - args = gpt_bigcode.ModelArgs( - model_type="gpt_bigcode", - n_embd=128, - n_layer=128, - n_inner=256, - n_head=4, - n_positions=1000, - layer_norm_epsilon=1e-5, - vocab_size=1024, - ) - model = gpt_bigcode.Model(args) - self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer) - - def test_nemotron(self): - from mlx_lm.models import nemotron - - args = nemotron.ModelArgs( - model_type="nemotron", - hidden_size=128, - hidden_act="gelu", - num_hidden_layers=4, - intermediate_size=256, - num_attention_heads=4, - norm_eps=1e-5, - vocab_size=1024, - num_key_value_heads=2, - ) - model = nemotron.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_phi3small(self): - from mlx_lm.models import phi3small - - args = phi3small.ModelArgs( - model_type="phi3small", - hidden_size=128, - dense_attention_every_n_layers=2, - ff_intermediate_size=256, - gegelu_limit=1.0, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=2, - layer_norm_epsilon=1e-4, - vocab_size=1000, - ) - model = 
phi3small.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_phimoe(self): - from mlx_lm.models import phimoe - - args = phimoe.ModelArgs( - model_type="phimoe", - vocab_size=320, - hidden_size=128, - intermediate_size=256, - num_hidden_layers=4, - num_attention_heads=4, - num_key_value_heads=4, - rope_scaling={ - "long_factor": [1.0] * 16, - "long_mscale": 1.243163121016122, - "original_max_position_embeddings": 4096, - "short_factor": [1.0] * 16, - "short_mscale": 1.243163121016122, - "type": "longrope", - }, - ) - model = phimoe.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_recurrent_gemma(self): - from mlx_lm.models import recurrent_gemma - - args = recurrent_gemma.ModelArgs( - model_type="recurrent_gemma", - hidden_size=128, - attention_bias=False, - conv1d_width=3, - intermediate_size=256, - logits_soft_cap=1.0, - num_attention_heads=4, - num_hidden_layers=4, - num_key_value_heads=2, - rms_norm_eps=1e-4, - rope_theta=1000, - attention_window_size=1024, - vocab_size=1000, - block_types=["recurrent", "recurrent", "attention"], - ) - model = recurrent_gemma.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_hunyuan(self): - from mlx_lm.models import hunyuan - - args = hunyuan.ModelArgs( - model_type="hunyuan", - hidden_size=128, - attention_bias=False, - intermediate_size=256, - num_attention_heads=4, - num_hidden_layers=4, - num_key_value_heads=2, - rms_norm_eps=1e-4, - rope_theta=1000, - vocab_size=1000, - moe_topk=2, - num_experts=2, - num_shared_expert=1, - use_mixed_mlp_moe=True, - use_qk_norm=True, - rope_scaling={ - "alpha": 1000.0, - "factor": 1.0, - "type": "dynamic", - }, - use_cla=True, - cla_share_factor=2, - ) - model = hunyuan.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_olmo2(self): - from mlx_lm.models import olmo2 - - args = olmo2.ModelArgs( - model_type="olmo2", - hidden_size=128, - attention_bias=False, - intermediate_size=256, - num_attention_heads=4, - num_hidden_layers=4, - num_key_value_heads=2, - rms_norm_eps=1e-4, - rope_theta=1000, - vocab_size=1000, - ) - model = olmo2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_exaone(self): - from mlx_lm.models import exaone - - args = exaone.ModelArgs( - model_type="exaone", - hidden_size=128, - num_layers=4, - intermediate_size=256, - num_attention_heads=8, - num_key_value_heads=2, - vocab_size=1000, - layer_norm_epsilon=1e-4, - rope_theta=10000, - ) - model = exaone.Model(args) - self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) - - def test_cohere2(self): - from mlx_lm.models import cohere2 - - args = cohere2.ModelArgs( - model_type="cohere2", - hidden_size=4096, - head_dim=128, - num_hidden_layers=40, - sliding_window=4096, - sliding_window_pattern=4, - ) - model = cohere2.Model(args) - self.model_test_runner( - model, args.model_type, args.vocab_size, args.num_hidden_layers - ) - - def test_internlm3(self): - from mlx_lm.models import internlm3 - - args = internlm3.ModelArgs( - model_type="internlm3", - hidden_size=1024, - num_hidden_layers=4, - intermediate_size=2048, - num_attention_heads=4, - rms_norm_eps=1e-5, - vocab_size=10_000, - ) - model = internlm3.Model(args) - self.model_test_runner( - model, args.model_type, 
args.vocab_size, args.num_hidden_layers - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py deleted file mode 100644 index c1860892..00000000 --- a/llms/tests/test_prompt_cache.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import copy -import os -import tempfile -import unittest - -import mlx.core as mx -from mlx_lm.models.cache import ( - KVCache, - MambaCache, - QuantizedKVCache, - RotatingKVCache, - load_prompt_cache, - make_prompt_cache, - save_prompt_cache, - trim_prompt_cache, -) -from mlx_lm.utils import generate_step, load - -HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" - - -class TestPromptCache(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.test_dir_fid = tempfile.TemporaryDirectory() - cls.test_dir = cls.test_dir_fid.name - - @classmethod - def tearDownClass(cls): - cls.test_dir_fid.cleanup() - - def test_save_load(self): - cache = [KVCache() for _ in range(4)] - for c in cache: - x = mx.random.uniform(shape=(1, 8, 10, 4)) - c.update_and_fetch(x, x) - cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") - save_prompt_cache(cache_file, cache) - loaded_cache = load_prompt_cache(cache_file) - self.assertTrue(len(cache), len(loaded_cache)) - for c, lc in zip(cache, loaded_cache): - self.assertEqual(c.offset, lc.offset) - self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) - self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) - - # Test with metadata - cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") - metadata = {"a": "b", "c": "d"} - save_prompt_cache(cache_file, cache, metadata) - _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) - self.assertEqual(metadata, loaded_metadata) - - def test_save_load_rotating_cache(self): - cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") - - # Test with rotating cache - cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)] - for c in cache: - x = mx.random.uniform(shape=(1, 8, 10, 4)) - c.update_and_fetch(x, x) - - save_prompt_cache(cache_file, cache) - loaded_cache = load_prompt_cache(cache_file) - self.assertTrue(len(cache), len(loaded_cache)) - for c, lc in zip(cache, loaded_cache): - self.assertEqual(c.offset, lc.offset) - self.assertEqual(c.keep, lc.keep) - self.assertEqual(c.max_size, lc.max_size) - self.assertEqual(c.step, lc.step) - self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) - self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) - - # Do a couple single token updates to get a rotation - for _ in range(2): - for c in cache: - x = mx.random.uniform(shape=(1, 8, 1, 4)) - c.update_and_fetch(x, x) - - save_prompt_cache(cache_file, cache) - loaded_cache = load_prompt_cache(cache_file) - - for c, lc in zip(cache, loaded_cache): - x = mx.random.uniform(shape=(1, 8, 1, 4)) - k, v = c.update_and_fetch(x, x) - lk, lv = lc.update_and_fetch(x, x) - self.assertEqual(c.offset, lc.offset) - self.assertTrue(mx.array_equal(k, lk)) - self.assertTrue(mx.array_equal(v, lv)) - - def test_save_load_mixed_cache(self): - cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") - - cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()] - for c in cache: - if isinstance(c, MambaCache): - c[0] = mx.random.uniform(shape=(4, 4, 4)) - c[1] = mx.random.uniform(shape=(4, 4, 4)) - else: - x = mx.random.uniform(shape=(4, 4, 7, 4)) - y = mx.random.uniform(shape=(4, 4, 7, 4)) - c.update_and_fetch(x, y) - - 
save_prompt_cache(cache_file, cache) - loaded_cache = load_prompt_cache(cache_file) - for c, lc in zip(cache, loaded_cache): - if isinstance(c, MambaCache): - self.assertTrue(mx.array_equal(c[0], lc[0])) - self.assertTrue(mx.array_equal(c[1], lc[1])) - else: - x = mx.random.uniform(shape=(4, 4, 1, 4)) - y = mx.random.uniform(shape=(4, 4, 1, 4)) - k, v = c.update_and_fetch(x, y) - lk, lv = lc.update_and_fetch(x, y) - self.assertEqual(c.offset, lc.offset) - self.assertTrue(mx.array_equal(k, lk)) - self.assertTrue(mx.array_equal(v, lv)) - - def test_cache_with_generate(self): - model, tokenizer = load(HF_MODEL_PATH) - prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - results = list(generate_step(prompt, model, max_tokens=4)) - toks, all_logits = zip(*results) - - prompt_cache = make_prompt_cache(model) - i = 0 - for tok, logits in generate_step( - prompt, model, prompt_cache=prompt_cache, max_tokens=2 - ): - self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i])) - i += 1 - - for tok, logits in generate_step( - mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1 - ): - i += 1 - self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i])) - - def test_trim_cache(self): - cache = [KVCache() for _ in range(2)] - for c in cache: - x = mx.random.uniform(shape=(1, 8, 10, 4)) - c.update_and_fetch(x, x) - - # Trim - num_trimmed = trim_prompt_cache(cache, 7) - self.assertEqual(num_trimmed, 7) - - # Trim more tokens than remain - num_trimmed = trim_prompt_cache(cache, 4) - self.assertEqual(num_trimmed, 3) - - # Can't trim mamba cache - cache = [MambaCache() for _ in range(2)] - for c in cache: - c.state = mx.zeros((5, 5)) - num_trimmed = trim_prompt_cache(cache, 7) - self.assertEqual(num_trimmed, 0) - - # All cache's have to be trimmable - cache = [MambaCache(), KVCache()] - cache[0].state = mx.zeros((5, 5)) - x = mx.random.uniform(shape=(1, 8, 10, 4)) - cache[1].update_and_fetch(x, x) - num_trimmed = trim_prompt_cache(cache, 1) - self.assertEqual(num_trimmed, 0) - - cache = [RotatingKVCache(max_size=6) for _ in range(2)] - for c in cache: - x = mx.random.uniform(shape=(1, 8, 5, 4)) - c.update_and_fetch(x, x) - - num_trimmed = trim_prompt_cache(cache, 4) - self.assertEqual(num_trimmed, 4) - - # Can't trim fixed-size KV cache after processing - # more than max_kv_size tokens - for c in cache: - x = mx.random.uniform(shape=(1, 8, 10, 4)) - c.update_and_fetch(x, x) - - num_trimmed = trim_prompt_cache(cache, 4) - self.assertEqual(num_trimmed, 0) - - cache = [QuantizedKVCache() for _ in range(2)] - for c in cache: - x = mx.random.uniform(shape=(1, 8, 10, 64)) - c.update_and_fetch(x, x) - - num_trimmed = trim_prompt_cache(cache, 7) - self.assertEqual(num_trimmed, 7) - - # Trim more tokens than remain - num_trimmed = trim_prompt_cache(cache, 4) - self.assertEqual(num_trimmed, 3) - - def test_trim_cache_with_generate(self): - model, tokenizer = load(HF_MODEL_PATH) - prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - - prompt_cache = make_prompt_cache(model) - - # Generate one token so we process the full prompt - last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache)) - last_tok = mx.array([last_tok]) - - # Generate two more tokens - results = zip( - range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) - ) - toks, all_logits = zip(*(r[1] for r in results)) - - # To get back to the cache just after processing the prompt, - # trim by 3 tokens - 
trim_prompt_cache(prompt_cache, 3) - - # Generate the same thing again - results = zip( - range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) - ) - second_toks, second_all_logits = zip(*(r[1] for r in results)) - self.assertEqual(toks, second_toks) - self.assertTrue( - all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) - ) - - def test_cache_copying(self): - cache = [KVCache()] - - x = mx.random.uniform(shape=(1, 8, 10, 4)) - cache[0].update_and_fetch(x, x) - - y = mx.random.uniform(shape=(1, 8, 1, 4)) - cache[0].update_and_fetch(y, y) - - old_cache = copy.deepcopy(cache) - - trim_prompt_cache(cache, 1) - - self.assertTrue(old_cache[0].offset, 11) - self.assertTrue(cache[0].offset, 10) - - z = mx.random.uniform(shape=(1, 8, 1, 4)) - cache[0].update_and_fetch(z, z) - - self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) - self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) - - def test_save_load_quantized_cache(self): - cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)] - for c in cache: - x = mx.random.uniform(shape=(1, 8, 10, 32)) - c.update_and_fetch(x, x) - cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") - save_prompt_cache(cache_file, cache) - loaded_cache = load_prompt_cache(cache_file) - self.assertTrue(loaded_cache[0].bits == cache[0].bits) - self.assertTrue(loaded_cache[0].group_size == cache[0].group_size) - self.assertTrue(len(cache), len(loaded_cache)) - for c, lc in zip(cache, loaded_cache): - self.assertEqual(c.offset, lc.offset) - # Loop over quantized tuple - for i in range(3): - self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i])) - self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i])) - - # Test with metadata - cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") - metadata = {"a": "b", "c": "d"} - save_prompt_cache(cache_file, cache, metadata) - _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) - self.assertEqual(metadata, loaded_metadata) - - def test_cache_to_quantized(self): - model, tokenizer = load(HF_MODEL_PATH) - prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - results = zip(range(4), generate_step(prompt, model)) - toks, all_logits = zip(*(r[1] for r in results)) - - prompt_cache = make_prompt_cache(model) - i = 0 - for _, (tok, logits) in zip( - range(2), generate_step(prompt, model, prompt_cache=prompt_cache) - ): - self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i])) - i += 1 - - prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache] - - for _, (tok, logits) in zip( - range(1), - generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), - ): - i += 1 - self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2)) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py deleted file mode 100644 index 7760c569..00000000 --- a/llms/tests/test_sample_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import unittest - -import mlx.core as mx -from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p - - -class TestSampleUtils(unittest.TestCase): - def test_apply_top_p(self): - probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] - logits = mx.log(probs) - - new_logits = apply_top_p(logits, 0.3) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0]) - - new_logits = 
apply_top_p(logits, 0.95) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertTrue(mx.allclose(probs.squeeze(), actual_probs)) - - probs = mx.array([0.0, 0.5, 0.4, 0.1])[None] - logits = mx.log(probs) - new_logits = apply_top_p(logits, 0.4) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0]) - - new_logits = apply_top_p(logits, 0.6) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertEqual( - [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0] - ) - - new_logits = apply_top_p(logits, 0.95) - actual_probs = mx.softmax(new_logits.squeeze()) - actual_rounded = [round(p, 4) for p in actual_probs.tolist()] - expected_rounded = [0.0, 0.5, 0.4, 0.1] - self.assertEqual(actual_rounded, expected_rounded) - self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0) - - # Batch mode works - probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]]) - logits = mx.log(probs) - new_logits = apply_top_p(logits, 0.5) - actual_probs = mx.softmax(new_logits, axis=-1) - self.assertEqual( - actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] - ) - - def test_apply_min_p(self): - probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] - logits = mx.log(probs) - new_logits = apply_min_p(logits, 0.8) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0]) - - probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] - logits = mx.log(probs) - new_logits = apply_min_p(logits, 0.05) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertTrue(mx.allclose(actual_probs, mx.squeeze(probs))) - - # Batch mode works - probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) - logits = mx.log(probs) - new_logits = apply_min_p(logits, 0.7) - actual_probs = mx.softmax(new_logits, axis=-1) - self.assertEqual( - actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] - ) - - def test_apply_top_k(self): - probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] - logits = mx.log(probs) - - new_logits = apply_top_k(logits, 1) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0]) - - probs = mx.array([0.6, 0.0, 0.1, 0.3])[None] - logits = mx.log(probs) - new_logits = apply_top_k(logits, 2) - actual_probs = mx.softmax(new_logits.squeeze()) - self.assertEqual( - [round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333] - ) - - # Batch mode works - probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) - logits = mx.log(probs) - - new_logits = apply_top_k(logits, 1) - actual_probs = mx.softmax(new_logits, axis=-1) - self.assertEqual( - actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]] - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py deleted file mode 100644 index ecf95f78..00000000 --- a/llms/tests/test_server.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright © 2024 Apple Inc. 
- -import http -import json -import threading -import unittest - -import requests -from mlx_lm.server import APIHandler -from mlx_lm.utils import load - - -class DummyModelProvider: - def __init__(self): - HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" - self.model, self.tokenizer = load(HF_MODEL_PATH) - self.model_key = (HF_MODEL_PATH, None) - - def load(self, model, adapter=None): - assert model in ["default_model", "chat_model"] - return self.model, self.tokenizer - - -class TestServer(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model_provider = DummyModelProvider() - cls.server_address = ("localhost", 0) - cls.httpd = http.server.HTTPServer( - cls.server_address, - lambda *args, **kwargs: APIHandler(cls.model_provider, *args, **kwargs), - ) - cls.port = cls.httpd.server_port - cls.server_thread = threading.Thread(target=cls.httpd.serve_forever) - cls.server_thread.daemon = True - cls.server_thread.start() - - @classmethod - def tearDownClass(cls): - cls.httpd.shutdown() - cls.httpd.server_close() - cls.server_thread.join() - - def test_handle_completions(self): - url = f"http://localhost:{self.port}/v1/completions" - - post_data = { - "model": "default_model", - "prompt": "Once upon a time", - "max_tokens": 10, - "temperature": 0.5, - "top_p": 0.9, - "repetition_penalty": 1.1, - "repetition_context_size": 20, - "stop": "stop sequence", - } - - response = requests.post(url, json=post_data) - - response_body = response.text - - self.assertIn("id", response_body) - self.assertIn("choices", response_body) - - def test_handle_chat_completions(self): - url = f"http://localhost:{self.port}/v1/chat/completions" - chat_post_data = { - "model": "chat_model", - "max_tokens": 10, - "temperature": 0.7, - "top_p": 0.85, - "repetition_penalty": 1.2, - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello!"}, - ], - } - response = requests.post(url, json=chat_post_data) - response_body = response.text - self.assertIn("id", response_body) - self.assertIn("choices", response_body) - - def test_handle_chat_completions_with_content_fragments(self): - url = f"http://localhost:{self.port}/v1/chat/completions" - chat_post_data = { - "model": "chat_model", - "max_tokens": 10, - "temperature": 0.7, - "top_p": 0.85, - "repetition_penalty": 1.2, - "messages": [ - { - "role": "system", - "content": [ - {"type": "text", "text": "You are a helpful assistant."} - ], - }, - {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, - ], - } - response = requests.post(url, json=chat_post_data) - response_body = response.text - self.assertIn("id", response_body) - self.assertIn("choices", response_body) - - def test_handle_models(self): - url = f"http://localhost:{self.port}/v1/models" - response = requests.get(url) - self.assertEqual(response.status_code, 200) - response_body = json.loads(response.text) - self.assertEqual(response_body["object"], "list") - self.assertIsInstance(response_body["data"], list) - self.assertGreater(len(response_body["data"]), 0) - model = response_body["data"][0] - self.assertIn("id", model) - self.assertEqual(model["object"], "model") - self.assertIn("created", model) - - def test_sequence_overlap(self): - from mlx_lm.server import sequence_overlap - - self.assertTrue(sequence_overlap([1], [1])) - self.assertTrue(sequence_overlap([1, 2], [1, 2])) - self.assertTrue(sequence_overlap([1, 3], [3, 4])) - self.assertTrue(sequence_overlap([1, 2, 3], [2, 3])) - - 
self.assertFalse(sequence_overlap([1], [2])) - self.assertFalse(sequence_overlap([1, 2], [3, 4])) - self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3])) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py deleted file mode 100644 index 3009d1b1..00000000 --- a/llms/tests/test_tokenizers.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import unittest -from pathlib import Path - -from huggingface_hub import snapshot_download -from mlx_lm.tokenizer_utils import ( - BPEStreamingDetokenizer, - NaiveStreamingDetokenizer, - SPMStreamingDetokenizer, - load_tokenizer, -) - - -class TestTokenizers(unittest.TestCase): - - def download_tokenizer(self, repo): - path = Path( - snapshot_download( - repo_id=repo, - allow_patterns=[ - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - "tokenizer.model", - ], - ) - ) - return load_tokenizer(path) - - def check_tokenizer(self, tokenizer): - def check(tokens): - expected_text = tokenizer.decode(tokens) - detokenizer = tokenizer.detokenizer - detokenizer.reset() - text = "" - for e, t in enumerate(tokens): - detokenizer.add_token(t) - seg = detokenizer.last_segment - text += seg - self.assertEqual(detokenizer.tokens, tokens[: e + 1]) - detokenizer.finalize() - text += detokenizer.last_segment - self.assertEqual(text, expected_text) - - tokens = tokenizer.encode("こんにちは!私の名前はAI") - check(tokens) - - tokens = tokenizer.encode("a ,b") - check(tokens) - - tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}') - check(tokens) - - tokens = tokenizer.encode("3 3") - check(tokens) - - tokens = tokenizer.encode("import 'package:flutter/material.dart';") - check(tokens) - - tokens = tokenizer.encode("hello\nworld") - check(tokens) - - def test_tokenizers(self): - tokenizer_repos = [ - ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), - ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer), - ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), - ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), - ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), - ("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer), - ] - for tokenizer_repo, expected_detokenizer in tokenizer_repos: - with self.subTest(tokenizer=tokenizer_repo): - tokenizer = self.download_tokenizer(tokenizer_repo) - tokenizer.decode([0, 1, 2]) - self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer)) - self.check_tokenizer(tokenizer) - - # Try one with a naive detokenizer - tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit") - tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) - self.check_tokenizer(tokenizer) - - def test_special_tokens(self): - tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" - tokenizer = self.download_tokenizer(tokenizer_repo) - - detokenizer = tokenizer.detokenizer - detokenizer.reset() - detokenizer.add_token(tokenizer.eos_token_id) - detokenizer.finalize() - - self.assertEqual(detokenizer.last_segment, tokenizer.eos_token) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_tuner_utils.py b/llms/tests/test_tuner_utils.py deleted file mode 100644 index 0256683c..00000000 --- a/llms/tests/test_tuner_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright © 2024 Apple Inc. 
- -import sys -import unittest -from io import StringIO -from unittest.mock import MagicMock - -import mlx.nn as nn -from mlx_lm.tuner.lora import LoRALinear -from mlx_lm.tuner.utils import print_trainable_parameters - - -class TestTunerUtils(unittest.TestCase): - def setUp(self): - self.capturedOutput = StringIO() - sys.stdout = self.capturedOutput - - def tearDown(self): - sys.stdout = sys.__stdout__ - - def test_quantized_print_trainable_parameters(self): - model = MagicMock() - quantized_linear = MagicMock(spec=nn.QuantizedLinear) - quantized_linear.weight = MagicMock(size=1e6) - quantized_linear.bits = 8 - lora_linear = MagicMock(spec=LoRALinear) - lora_linear.weight = MagicMock(size=2e6) - lora_linear.parameters.return_value = [lora_linear.weight] - - linear = MagicMock(spec=nn.Linear) - linear.weight = MagicMock(size=3e6) - linear.parameters.return_value = [linear.weight] - - model.leaf_modules.return_value = { - "quantized_linear": quantized_linear, - "lora_linear": lora_linear, - "linear": linear, - } - - model.trainable_parameters.return_value = { - "layer1.weight": MagicMock(size=1e6), - "layer3.weight": MagicMock(size=2e6), - } - expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n" - print_trainable_parameters(model) - self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits) - self.capturedOutput.truncate(0) - self.capturedOutput.seek(0) - - quantized_linear.weight = MagicMock(size=1e6) - quantized_linear.bits = 4 - expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n" - print_trainable_parameters(model) - self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits) - self.capturedOutput.truncate(0) - self.capturedOutput.seek(0) - - def test_print_trainable_parameters(self): - model = MagicMock() - linear1 = MagicMock(spec=nn.Linear) - linear1.weight = MagicMock(size=1e6) - linear1.parameters.return_value = [linear1.weight] - linear2 = MagicMock(spec=nn.Linear) - linear2.weight = MagicMock(size=2e6) - linear2.parameters.return_value = [linear2.weight] - lora_linear = MagicMock(spec=LoRALinear) - lora_linear.weight = MagicMock(size=3e6) - lora_linear.parameters.return_value = [lora_linear.weight] - model.leaf_modules.return_value = { - "linear1": linear1, - "linear2": linear2, - "lora_linear": lora_linear, - } - - model.trainable_parameters.return_value = { - "layer1.weight": MagicMock(size=1e6), - "layer3.weight": MagicMock(size=2e6), - } - expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n" - print_trainable_parameters(model) - self.assertEqual(self.capturedOutput.getvalue(), expected_output) - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/tests/test_utils.py b/llms/tests/test_utils.py deleted file mode 100644 index 18cfa8c7..00000000 --- a/llms/tests/test_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright © 2024 Apple Inc. 
-
-import os
-import tempfile
-import unittest
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx.utils import tree_flatten
-from mlx_lm import utils
-
-HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-
-
-class TestUtils(unittest.TestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        cls.test_dir_fid = tempfile.TemporaryDirectory()
-        cls.test_dir = cls.test_dir_fid.name
-        if not os.path.isdir(cls.test_dir):
-            os.mkdir(cls.test_dir_fid.name)
-
-    @classmethod
-    def tearDownClass(cls):
-        cls.test_dir_fid.cleanup()
-
-    def test_load(self):
-        model, _ = utils.load(HF_MODEL_PATH)
-
-        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)
-
-        mx.eval(model_lazy.parameters())
-
-        p1 = model.layers[0].mlp.up_proj.weight
-        p2 = model_lazy.layers[0].mlp.up_proj.weight
-        self.assertTrue(mx.allclose(p1, p2))
-
-    def test_make_shards(self):
-        from mlx_lm.models import llama
-
-        args = llama.ModelArgs(
-            model_type="llama",
-            hidden_size=2048,
-            num_hidden_layers=32,
-            intermediate_size=4096,
-            num_attention_heads=32,
-            rms_norm_eps=1e-5,
-            vocab_size=30_000,
-        )
-        model = llama.Model(args)
-        weights = tree_flatten(model.parameters())
-        gb = sum(p.nbytes for _, p in weights) // 2**30
-        shards = utils.make_shards(dict(weights), 1)
-        self.assertTrue(gb <= len(shards) <= gb + 1)
-
-    def test_quantize(self):
-        from mlx_lm.models import llama
-
-        args = llama.ModelArgs(
-            model_type="llama",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-        )
-        model = llama.Model(args)
-        weights, config = utils.quantize_model(model, {}, 64, 4)
-        self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights)
-        self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights)
-        self.assertEqual(config["quantization"]["group_size"], 64)
-        self.assertEqual(config["quantization"]["bits"], 4)
-
-    def test_convert(self):
-        mlx_path = os.path.join(self.test_dir, "mlx_model")
-
-        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=True)
-        model, _ = utils.load(mlx_path)
-        self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear))
-        self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))
-
-        # Check model weights have right type
-        mlx_path = os.path.join(self.test_dir, "mlx_model_bf16")
-        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
-        model, _ = utils.load(mlx_path)
-
-        self.assertEqual(model.layers[0].mlp.up_proj.weight.dtype, mx.bfloat16)
-        self.assertEqual(model.layers[-1].mlp.up_proj.weight.dtype, mx.bfloat16)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py
deleted file mode 100644
index 8da19afb..00000000
--- a/llms/tests/test_utils_load_model.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import unittest
-from pathlib import Path
-
-import mlx.nn as nn
-from mlx_lm.models.qwen2 import Model as Qwen2Model
-from mlx_lm.utils import get_model_path, load_model
-
-HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-
-
-class TestLoadModelCustomGetClasses(unittest.TestCase):
-
-    def test_load_model_with_custom_get_classes(self):
-        class CustomQwenModel(nn.Module):
-            def __init__(self, args):
-                super().__init__()
-                self.config = args
-                self.custom_attribute = "This is a custom model"
-
-            def load_weights(self, weights, **kwargs):
-                self.qwenWeights = weights
-
-        class CustomQwenConfig:
-            @classmethod
-            def from_dict(cls, config):
-                instance = cls()
-                for k, v in config.items():
-                    setattr(instance, k, v)
-                return instance
-
-        def custom_get_classes(config):
-            return CustomQwenModel, CustomQwenConfig
-
-        model_path = get_model_path(HF_MODEL_PATH)
-        model, _ = load_model(model_path, get_model_classes=custom_get_classes)
-
-        self.assertIsInstance(model, CustomQwenModel)
-        self.assertTrue(hasattr(model, "custom_attribute"))
-        self.assertEqual(model.custom_attribute, "This is a custom model")
-        self.assertTrue(hasattr(model, "qwenWeights"))
-
-    def test_load_model_with_default_get_classes(self):
-        model_path = get_model_path(HF_MODEL_PATH)
-        model, _ = load_model(model_path)
-
-        self.assertIsInstance(model, Qwen2Model)
-
-
-if __name__ == "__main__":
-    unittest.main()