diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2d0dcf60..15676360 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -222,6 +222,17 @@ data formats. Here are examples of these formats: } ``` + +The format for the `arguments` field in a function varies for different models. +Common formats include JSON strings and dictionaries. The example provided +follows the format used by +[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples) +and [Mistral +AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct). +A dictionary format is used in Hugging Face's [chat +templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example). +Refer to the documentation for the model you are fine-tuning for more details. + `completions`: @@ -241,7 +252,7 @@ each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than -> one example per line and do not split an example accross multiple lines. +> one example per line and do not split an example across multiple lines. ### Hugging Face Datasets diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index a3f43f71..9bac77a5 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--dtype", - help="Type to save the parameters, ignored if -q is given.", + help="Type to save the non-quantized parameters.", type=str, choices=["float16", "bfloat16", "float32"], default="float16", diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index ccc327a8..64951ae4 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -111,7 +111,7 @@ class MLP(nn.Module): self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index d2740dc1..84f498e9 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -205,7 +205,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: + if "conv1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) return weights diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 06a307a6..5595d311 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -440,7 +440,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv_1d.weight" in k and v.ndim == 3: + if "conv_1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4f872982..92741b68 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -720,7 +720,7 @@ def convert( model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters())) - dtype = mx.float16 if quantize else getattr(mx, dtype) + dtype = getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} if quantize and dequantize: