From 66e7bcb8866a050727849d9a303c54a0119f0f99 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Oct 2024 09:56:45 -0700 Subject: [PATCH 1/4] override dtype with quant (#1062) --- llms/mlx_lm/convert.py | 2 +- llms/mlx_lm/models/gemma2.py | 2 +- llms/mlx_lm/utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index a3f43f71..9bac77a5 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--dtype", - help="Type to save the parameters, ignored if -q is given.", + help="Type to save the non-quantized parameters.", type=str, choices=["float16", "bfloat16", "float32"], default="float16", diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index ccc327a8..64951ae4 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -111,7 +111,7 @@ class MLP(nn.Module): self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4f872982..92741b68 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -720,7 +720,7 @@ def convert( model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters())) - dtype = mx.float16 if quantize else getattr(mx, dtype) + dtype = getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} if quantize and dequantize: From d1d480867b2248fb95fedcf7f9d33b41689d9991 Mon Sep 17 00:00:00 2001 From: madroid Date: Wed, 23 Oct 2024 03:19:11 +0800 Subject: [PATCH 2/4] LoRA: update tools datasets docs (#1063) * LoRA: update tools datasets docs * nits * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2d0dcf60..15676360 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -222,6 +222,17 @@ data formats. Here are examples of these formats: } ``` + +The format for the `arguments` field in a function varies for different models. +Common formats include JSON strings and dictionaries. The example provided +follows the format used by +[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples) +and [Mistral +AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct). +A dictionary format is used in Hugging Face's [chat +templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example). +Refer to the documentation for the model you are fine-tuning for more details. + `completions`: @@ -241,7 +252,7 @@ each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than -> one example per line and do not split an example accross multiple lines. +> one example per line and do not split an example across multiple lines. ### Hugging Face Datasets From 9000e280aeb56c2bcce128001ab157030095687a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Oct 2024 15:44:08 -0700 Subject: [PATCH 3/4] fix mamba models conversion (#1065) --- llms/mlx_lm/models/mamba.py | 2 +- llms/mlx_lm/models/recurrent_gemma.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index d2740dc1..84f498e9 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -205,7 +205,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: + if "conv1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) return weights diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 06a307a6..5595d311 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -440,7 +440,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv_1d.weight" in k and v.ndim == 3: + if "conv_1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") From 4971462bf0dd7bba07d9f18fb0fd2752a51fde40 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Fri, 25 Oct 2024 05:56:17 +0100 Subject: [PATCH 4/4] feat(clip): add linear probe evaluation script (#960) --- clip/linear_probe.py | 56 +++++++++++++++++++++++++++++++++++++++++++ clip/requirements.txt | 1 + 2 files changed, 57 insertions(+) create mode 100644 clip/linear_probe.py diff --git a/clip/linear_probe.py b/clip/linear_probe.py new file mode 100644 index 00000000..2649e397 --- /dev/null +++ b/clip/linear_probe.py @@ -0,0 +1,56 @@ +# Mirror of the Linear Probe Evaluation Script +# from the official CLIP Repository. + +import mlx.core as mx +import numpy as np +from image_processor import CLIPImageProcessor +from mlx.data.datasets import load_cifar10 +from model import CLIPModel +from PIL import Image +from sklearn.linear_model import LogisticRegression +from tqdm import tqdm + + +def get_cifar10(batch_size, root=None): + tr = load_cifar10(root=root).batch(batch_size) + test = load_cifar10(root=root, train=False).batch(batch_size) + + return tr, test + + +def get_features(model, image_proc, iter): + all_features = [] + all_labels = [] + + for batch in tqdm(iter): + image, label = batch["image"], batch["label"] + x = image_proc([Image.fromarray(im) for im in image]) + y = mx.array(label) + + image_embeds = model.get_image_features(x) + mx.eval(image_embeds) + + all_features.append(image_embeds) + all_labels.append(y) + + return mx.concatenate(all_features), mx.concatenate(all_labels) + + +if __name__ == "__main__": + model = CLIPModel.from_pretrained("mlx_model") + image_proc = CLIPImageProcessor.from_pretrained("mlx_model") + + train_iter, test_iter = get_cifar10(batch_size=256) + train_features, train_labels = get_features(model, image_proc, train_iter) + test_features, test_labels = get_features(model, image_proc, test_iter) + + # Perform logistic regression + # NOTE: The value of C should be determined via a hyperparameter sweep + # using a validation split + classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) + classifier.fit(train_features, train_labels) + + # Evaluate using the logistic regression classifier + predictions = classifier.predict(test_features) + accuracy = (test_labels.squeeze() == predictions).mean().item() * 100 + print(f"Accuracy = {accuracy:.3f}") diff --git a/clip/requirements.txt b/clip/requirements.txt index 74f826ea..8e05620e 100644 --- a/clip/requirements.txt +++ b/clip/requirements.txt @@ -1,4 +1,5 @@ mlx +mlx-data numpy transformers torch