Merge branch 'main' into adding-full-finetuning

Gökdeniz Gülmez, 2024-09-28 19:54:37 +02:00, committed by GitHub
12 changed files with 474 additions and 27 deletions

View File

@@ -13,5 +13,5 @@ MLX Examples was developed with contributions from the following individuals:
- Gabrijel Boduljak: Implemented `CLIP`.
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
-- Shiyu Li: Added the `Segment Anything` model.
-- Gökdeniz Gülmez: Added the `MiniCPM` model and support for full fine-tuning.
+- Shiyu Li: Added the `Segment Anything Model`.
+- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`, and for full fine-tuning.

View File

@@ -173,8 +173,8 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.

-Currently, `*.jsonl` files support three data formats: `chat`,
-`completions`, and `text`. Here are three examples of these formats:
+Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
+data formats. Here are examples of these formats:
`chat`:
@@ -182,6 +182,58 @@ Currently, `*.jsonl` files support three data formats: `chat`,
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
`tools`:
```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```
<details>
<summary>View the expanded single data tool format</summary>

```jsonl
{
  "messages": [
    { "role": "user", "content": "What is the weather in San Francisco?" },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "call_id",
          "type": "function",
          "function": {
            "name": "get_current_weather",
            "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
          }
        }
      ]
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and country, e.g. San Francisco, USA"
            },
            "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
          },
          "required": ["location", "format"]
        }
      }
    }
  ]
}
```

</details>
`completions`:
```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
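Whichever format is used, each line must be a self-contained JSON object. A quick way to sanity-check a training file before a run is a short script like the following (a minimal sketch; the path is hypothetical, and the key check mirrors how the loader tells the formats apart):

```python
import json

# Verify that every line parses as JSON and carries a recognizable format key:
# "messages" -> chat/tools, "prompt" -> completions, "text" -> text.
with open("data/train.jsonl") as f:  # hypothetical path
    for lineno, line in enumerate(f, start=1):
        record = json.loads(line)
        if not any(key in record for key in ("messages", "prompt", "text")):
            raise ValueError(f"line {lineno}: unrecognized data format")
```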
@@ -228,11 +280,13 @@ hf_dataset:
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

-In general, for the `chat` and `completions` formats, Hugging Face [chat
-templates](https://huggingface.co/blog/chat-templates) are used. This applies
-the model's chat template by default. If the model does not have a chat
-template, then Hugging Face will use a default. For example, the final text in
-the `chat` example above with Hugging Face's default template becomes:
+In general, for the `chat`, `tools`, and `completions` formats, Hugging Face
+[chat
+templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
+are used. This applies the model's chat template by default. If the model does
+not have a chat template, then Hugging Face will use a default. For example,
+the final text in the `chat` example above with Hugging Face's default template
+becomes:
```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assist you today?<|im_end|>
```
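For reference, the rendering above can be reproduced directly with the `transformers` tokenizer API. A minimal sketch (the model id is only an example of a model with a ChatML-style template):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")  # example model id
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello."},
    {"role": "assistant", "content": "How can I assist you today?"},
]
# Render the conversation into the flat training text shown above.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```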

View File

@@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
  relative to the directory the server was started in.

### List Models

Use the `v1/models` endpoint to list available models:

```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```

This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.

llms/mlx_lm/models/mamba.py (new file, 231 lines)
View File

@@ -0,0 +1,231 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    state_size: int
    num_hidden_layers: int
    conv_kernel: int
    use_bias: bool
    use_conv_bias: bool
    time_step_rank: int
    tie_word_embeddings: bool = True

    def __post_init__(self):
        # Accept the original Mamba config field names as aliases.
        if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
            self.hidden_size = self.d_model
        if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
            self.intermediate_size = self.d_inner
        if not hasattr(self, "state_size") and hasattr(self, "d_state"):
            self.state_size = self.d_state
        if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
            self.num_hidden_layers = self.n_layer
        if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
            self.num_hidden_layers = self.n_layers
        if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
            self.conv_kernel = self.d_conv
        if not hasattr(self, "use_bias") and hasattr(self, "bias"):
            self.use_bias = self.bias
        if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
            self.use_conv_bias = self.conv_bias

        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)

class MambaCache:
    def __init__(self):
        self.cache = [None, None]

    def __setitem__(self, idx, value):
        self.cache[idx] = value

    def __getitem__(self, idx):
        return self.cache[idx]

    @property
    def state(self):
        return self.cache

class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.weight = mx.random.normal((self.channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        groups, K, _ = self.weight.shape

        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=groups)
        if self.bias is not None:
            y = y + self.bias

        # Return the last K - 1 input frames so the next call can continue
        # the causal convolution where this one left off.
        return y, x[:, -K + 1 :, :]

class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.ssm_state_size = args.state_size
        self.conv_kernel_size = args.conv_kernel
        self.intermediate_size = args.intermediate_size
        self.time_step_rank = int(args.time_step_rank)
        self.use_conv_bias = args.use_conv_bias

        self.in_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
        )

        self.conv1d = DepthWiseConv1d(
            channels=self.intermediate_size,
            kernel_size=self.conv_kernel_size,
            bias=self.use_conv_bias,
            padding=self.conv_kernel_size - 1,
        )

        self.x_proj = nn.Linear(
            self.intermediate_size,
            self.time_step_rank + 2 * self.ssm_state_size,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        A = mx.repeat(
            mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
            repeats=self.intermediate_size,
            axis=0,
        )
        self.A_log = mx.log(A)
        self.D = mx.ones([self.intermediate_size])

        self.out_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=args.use_bias
        )

    def ssm_step(self, x, state=None):
        # Selective state-space recurrence for a single time step:
        #   h_t = exp(delta * A) * h_{t-1} + (delta * x_t) * B
        #   y_t = C @ h_t + D * x_t
        A = -mx.exp(self.A_log)
        D = self.D
        deltaBC = self.x_proj(x)
        delta, B, C = mx.split(
            deltaBC,
            indices_or_sections=[
                self.time_step_rank,
                self.time_step_rank + self.ssm_state_size,
            ],
            axis=-1,
        )
        delta = nn.softplus(self.dt_proj(delta))
        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
        if state is not None:
            new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
        y = y + D * x
        return y, new_state

    def __call__(self, x, cache):
        B, T, D = x.shape
        if cache is None:
            cache = [None, None]

        # Process the sequence one token at a time, threading the convolution
        # cache (cache[0]) and the SSM state (cache[1]) through each step.
        outputs = []
        for t in range(T):
            xt = x[:, t, :]
            xz = self.in_proj(xt)
            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
            x_t = conv_out.squeeze(1)
            x_t = nn.silu(x_t)
            y_t, cache[1] = self.ssm_step(x_t, cache[1])
            z_t = nn.silu(z_t)
            # Gate the SSM output and project back to the hidden size.
            output_t = y_t * z_t
            output_t = self.out_proj(output_t)
            outputs.append(output_t)
        output = mx.stack(outputs, axis=1)
        return output

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        return self.mixer(self.norm(x), cache) + x


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            x = layer(x, c)
        return self.norm_f(x)

class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        B, T = inputs.shape
        x = self.backbone(inputs, cache)

        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(x)
        else:
            logits = self.lm_head(x)
        return logits

    def sanitize(self, weights):
        # PyTorch checkpoints store conv weights as (channels, 1, kernel_size);
        # move the kernel axis to match this module's (channels, kernel_size, 1).
        for k, v in weights.items():
            if "conv1d.weight" in k and v.ndim == 3:
                weights[k] = v.moveaxis(2, 1)
        return weights

    def make_cache(self, batch_size: int = 1):
        return [MambaCache() for _ in range(len(self.layers))]

    @property
    def layers(self):
        return self.backbone.layers
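Once a checkpoint is converted, the model runs through the usual `mlx_lm` entry points. A minimal usage sketch (the repo id below is hypothetical; any Mamba checkpoint converted for MLX would work the same way):

```python
from mlx_lm import load, generate

# Hypothetical converted Mamba checkpoint on the Hugging Face Hub.
model, tokenizer = load("mlx-community/mamba-370m-hf")
print(generate(model, tokenizer, prompt="The capital of France is", max_tokens=20))
```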

View File

@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union

import mlx.core as mx
from huggingface_hub import scan_cache_dir

from .utils import generate_step, load
@@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
        prompt = self.tokenizer.encode(prompt_text)
        return mx.array(prompt)

    def do_GET(self):
        """
        Respond to a GET request from a client.
        """
        if self.path == "/v1/models":
            self.handle_models_request()
        else:
            self._set_completion_headers(404)
            self.end_headers()
            self.wfile.write(b"Not Found")

    def handle_models_request(self):
        """
        Handle a GET request for the /v1/models endpoint.
        """
        self._set_completion_headers(200)
        self.end_headers()

        # Scan the cache directory for downloaded mlx models
        hf_cache_info = scan_cache_dir()
        downloaded_models = [
            repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
        ]

        # Create a list of available models
        models = [
            {
                "id": repo.repo_id,
                "object": "model",
                "created": self.created,
            }
            for repo in downloaded_models
        ]

        response = {"object": "list", "data": models}

        response_json = json.dumps(response).encode()
        self.wfile.write(response_json)
        self.wfile.flush()

def run(
    host: str,

View File

@@ -36,7 +36,10 @@ class ChatDataset(Dataset):
    def __getitem__(self, idx: int):
        messages = self._data[idx]["messages"]
        text = self._tokenizer.apply_chat_template(
-           messages, tokenize=False, add_generation_prompt=True
+           messages,
+           tools=self._data[idx].get("tools", None),
+           tokenize=False,
+           add_generation_prompt=True,
        )
        return text
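This is what the change enables: when a dataset row carries a `tools` field, the tool schema is rendered into the training prompt by the chat template. A standalone sketch (assuming a tokenizer whose chat template accepts `tools`; the model id is only an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model id
row = {
    "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
}
print(tokenizer.apply_chat_template(
    row["messages"],
    tools=row.get("tools", None),
    tokenize=False,
    add_generation_prompt=True,
))  # the rendered prompt now includes the tool schema
```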

View File

@@ -100,9 +100,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
            # Encode batch
            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
            for b in batch:
-               if b[-1] == tokenizer.eos_token_id:
-                   print("[WARNING] Example already has an EOS token appended")
-               else:
+               if b[-1] != tokenizer.eos_token_id:
                    b.append(tokenizer.eos_token_id)

            lengths = [len(x) for x in batch]

View File

@@ -52,7 +52,6 @@ def linear_to_lora_layers(
        use_dora (bool): If True, uses DoRA instead of LoRA.
            Default: ``False``
    """
    num_layers = len(model.layers)

    if num_lora_layers < 0:
@@ -140,6 +139,15 @@ def linear_to_lora_layers(
"self_attn.kv_b_proj",
]
)
elif model.model_type == "mamba":
keys = set(
[
"mixer.in_proj",
"mixer.x_proj",
"mixer.dt_proj",
"mixer.out_proj",
]
)
else:
raise ValueError(f"Lora does not support {model.model_type}")
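With these keys registered, LoRA fine-tuning a Mamba model goes through the same path as the attention-based models. A minimal sketch (the checkpoint id is hypothetical, and the `config` keys follow the `lora_parameters` convention used by the trainer):

```python
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers

model, tokenizer = load("mlx-community/mamba-370m-hf")  # hypothetical checkpoint
model.freeze()

# Wrap the mixer projections listed above with LoRA adapters in the last 8 layers.
linear_to_lora_layers(
    model,
    num_lora_layers=8,
    config={"rank": 8, "scale": 20.0, "dropout": 0.0},
)
```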

View File

@@ -154,10 +154,11 @@ def generate_step(
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
-   logit_bias: Optional[Dict[int, float]] = None,
    prefill_step_size: int = 512,
    max_kv_size: Optional[int] = None,
    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+   logit_bias: Optional[Dict[int, float]] = None,
+   logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.
@@ -177,10 +178,13 @@ def generate_step(
            probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
            be filtered by min_p sampling.
-       logit_bias (dictionary, optional): Additive logit bias.
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
            entries (except the first 4 tokens) will be overwritten.
+       logit_bias (dictionary, optional): Additive logit bias.
+       logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
+           A function that takes tokens and logits and returns the processed
+           logits. Default: ``None``.

    Yields:
        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -188,10 +192,6 @@ def generate_step(
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
@@ -214,6 +214,7 @@ def generate_step(
        )

    y = prompt
+   tokens = None

    # Create the KV cache for generation
    cache = make_kv_caches(model, max_kv_size)
@@ -233,11 +234,23 @@ def generate_step(
        if repetition_context_size:
            repetition_context = repetition_context[-repetition_context_size:]

+   if logit_bias:
+       indices = mx.array(list(logit_bias.keys()))
+       values = mx.array(list(logit_bias.values()))

    def _step(y):
        nonlocal repetition_context
        logits = model(y[None], cache=cache)
        logits = logits[:, -1, :]

+       if logits_processor:
+           nonlocal tokens
+           tokens = mx.concat([tokens, y]) if tokens is not None else y
+           logits = logits_processor(tokens, logits)

+       if logit_bias:
+           logits[:, indices] += values

        if repetition_penalty:
            logits = apply_repetition_penalty(
                logits, repetition_context, repetition_penalty
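For illustration, a minimal `logits_processor` matching the documented signature (a sketch; the penalized token id is arbitrary):

```python
import mlx.core as mx

def demote_token_zero(tokens: mx.array, logits: mx.array) -> mx.array:
    # Push token id 0 far down at every step (arbitrary example); the
    # update mirrors how generate_step applies its additive logit bias.
    logits[:, 0] -= 1e9
    return logits

# Usage sketch: generate_step(prompt, model, logits_processor=demote_token_zero)
```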

View File

@@ -0,0 +1,55 @@
# Copyright © 2024 Apple Inc.

import unittest

from mlx_lm.utils import generate, load


class TestGenerate(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        cls.model, cls.tokenizer = load(HF_MODEL_PATH)

    def test_generate(self):
        # Simple test that generation runs
        text = generate(
            self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
        )

    def test_generate_with_logit_bias(self):
        logit_bias = {0: 2000.0, 1: -20.0}
        text = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logit_bias=logit_bias,
        )
        self.assertEqual(text, "!!!!!")

    def test_generate_with_processor(self):
        init_toks = self.tokenizer.encode("hello")

        all_toks = None

        def logits_processor(toks, logits):
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processor=logits_processor,
        )

        self.assertEqual(len(all_toks), len(init_toks) + 5)


if __name__ == "__main__":
    unittest.main()

View File

@@ -5,6 +5,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map

from mlx_lm.models.base import KVCache, RotatingKVCache
from mlx_lm.utils import make_kv_caches

class TestModels(unittest.TestCase):
@@ -100,13 +101,7 @@ class TestModels(unittest.TestCase):
        self.assertEqual(outputs.shape, (1, 2, vocab_size))
        self.assertEqual(outputs.dtype, t)

-       kv_heads = (
-           [model.n_kv_heads] * len(model.layers)
-           if isinstance(model.n_kv_heads, int)
-           else model.n_kv_heads
-       )
-       cache = [KVCache(model.head_dim, n) for n in kv_heads]
+       cache = make_kv_caches(model)

        outputs = model(inputs, cache)
        self.assertEqual(outputs.shape, (1, 2, vocab_size))
        self.assertEqual(outputs.dtype, t)
@@ -397,6 +392,26 @@ class TestModels(unittest.TestCase):
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mamba(self):
        from mlx_lm.models import mamba

        args = mamba.ModelArgs(
            model_type="mamba",
            vocab_size=10000,
            use_bias=False,
            use_conv_bias=True,
            conv_kernel=4,
            hidden_size=768,
            num_hidden_layers=24,
            state_size=16,
            intermediate_size=1536,
            time_step_rank=48,
        )
        model = mamba.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt2(self):
        from mlx_lm.models import gpt2

View File

@@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.

import http
import json
import threading
import unittest
@@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
def test_handle_models(self):
url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url)
self.assertEqual(response.status_code, 200)
response_body = json.loads(response.text)
self.assertEqual(response_body["object"], "list")
self.assertIsInstance(response_body["data"], list)
self.assertGreater(len(response_body["data"]), 0)
model = response_body["data"][0]
self.assertIn("id", model)
self.assertEqual(model["object"], "model")
self.assertIn("created", model)
def test_sequence_overlap(self):
from mlx_lm.server import sequence_overlap