Merge branch 'ml-explore:main' into main

This commit is contained in:
锦此 2024-10-24 15:25:53 +08:00 committed by GitHub
commit a039f11b11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 17 additions and 6 deletions

View File

@ -222,6 +222,17 @@ data formats. Here are examples of these formats:
} }
``` ```
The format for the `arguments` field in a function varies for different models.
Common formats include JSON strings and dictionaries. The example provided
follows the format used by
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
and [Mistral
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
A dictionary format is used in Hugging Face's [chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
Refer to the documentation for the model you are fine-tuning for more details.
</details> </details>
`completions`: `completions`:
@ -241,7 +252,7 @@ each line not expected by the loader will be ignored.
> [!NOTE] > [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than > Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example accross multiple lines. > one example per line and do not split an example across multiple lines.
### Hugging Face Datasets ### Hugging Face Datasets

View File

@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser:
) )
parser.add_argument( parser.add_argument(
"--dtype", "--dtype",
help="Type to save the parameters, ignored if -q is given.", help="Type to save the non-quantized parameters.",
type=str, type=str,
choices=["float16", "bfloat16", "float32"], choices=["float16", "bfloat16", "float32"],
default="float16", default="float16",

View File

@ -111,7 +111,7 @@ class MLP(nn.Module):
self.up_proj = nn.Linear(dim, hidden_dim, bias=False) self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array: def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):

View File

@ -205,7 +205,7 @@ class Model(nn.Module):
def sanitize(self, weights): def sanitize(self, weights):
for k, v in weights.items(): for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3: if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1) weights[k] = v.moveaxis(2, 1)
return weights return weights

View File

@ -440,7 +440,7 @@ class Model(nn.Module):
def sanitize(self, weights): def sanitize(self, weights):
for k, v in weights.items(): for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3: if "conv_1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1) weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights: if "lm_head.weight" not in weights:
self.pop("lm_head") self.pop("lm_head")

View File

@ -720,7 +720,7 @@ def convert(
model, config, tokenizer = fetch_from_hub(model_path, lazy=True) model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters())) weights = dict(tree_flatten(model.parameters()))
dtype = mx.float16 if quantize else getattr(mx, dtype) dtype = getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()} weights = {k: v.astype(dtype) for k, v in weights.items()}
if quantize and dequantize: if quantize and dequantize: