mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-11 03:36:42 +08:00
Merge branch 'ml-explore:main' into adding-support-for-mamba2
This commit is contained in:
commit
3b70708201
56
clip/linear_probe.py
Normal file
56
clip/linear_probe.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Mirror of the Linear Probe Evaluation Script
|
||||
# from the official CLIP Repository.
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from image_processor import CLIPImageProcessor
|
||||
from mlx.data.datasets import load_cifar10
|
||||
from model import CLIPModel
|
||||
from PIL import Image
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root).batch(batch_size)
|
||||
test = load_cifar10(root=root, train=False).batch(batch_size)
|
||||
|
||||
return tr, test
|
||||
|
||||
|
||||
def get_features(model, image_proc, iter):
|
||||
all_features = []
|
||||
all_labels = []
|
||||
|
||||
for batch in tqdm(iter):
|
||||
image, label = batch["image"], batch["label"]
|
||||
x = image_proc([Image.fromarray(im) for im in image])
|
||||
y = mx.array(label)
|
||||
|
||||
image_embeds = model.get_image_features(x)
|
||||
mx.eval(image_embeds)
|
||||
|
||||
all_features.append(image_embeds)
|
||||
all_labels.append(y)
|
||||
|
||||
return mx.concatenate(all_features), mx.concatenate(all_labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = CLIPModel.from_pretrained("mlx_model")
|
||||
image_proc = CLIPImageProcessor.from_pretrained("mlx_model")
|
||||
|
||||
train_iter, test_iter = get_cifar10(batch_size=256)
|
||||
train_features, train_labels = get_features(model, image_proc, train_iter)
|
||||
test_features, test_labels = get_features(model, image_proc, test_iter)
|
||||
|
||||
# Perform logistic regression
|
||||
# NOTE: The value of C should be determined via a hyperparameter sweep
|
||||
# using a validation split
|
||||
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
|
||||
classifier.fit(train_features, train_labels)
|
||||
|
||||
# Evaluate using the logistic regression classifier
|
||||
predictions = classifier.predict(test_features)
|
||||
accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
|
||||
print(f"Accuracy = {accuracy:.3f}")
|
@ -1,4 +1,5 @@
|
||||
mlx
|
||||
mlx-data
|
||||
numpy
|
||||
transformers
|
||||
torch
|
||||
|
@ -222,6 +222,17 @@ data formats. Here are examples of these formats:
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
The format for the `arguments` field in a function varies for different models.
|
||||
Common formats include JSON strings and dictionaries. The example provided
|
||||
follows the format used by
|
||||
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
|
||||
and [Mistral
|
||||
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
|
||||
A dictionary format is used in Hugging Face's [chat
|
||||
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
|
||||
Refer to the documentation for the model you are fine-tuning for more details.
|
||||
|
||||
</details>
|
||||
|
||||
`completions`:
|
||||
@ -241,7 +252,7 @@ each line not expected by the loader will be ignored.
|
||||
|
||||
> [!NOTE]
|
||||
> Each example in the datasets must be on a single line. Do not put more than
|
||||
> one example per line and do not split an example accross multiple lines.
|
||||
> one example per line and do not split an example across multiple lines.
|
||||
|
||||
### Hugging Face Datasets
|
||||
|
||||
|
@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
help="Type to save the non-quantized parameters.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
|
@ -111,7 +111,7 @@ class MLP(nn.Module):
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
@ -205,7 +205,7 @@ class Model(nn.Module):
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv1d.weight" in k and v.ndim == 3:
|
||||
if "conv1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
|
@ -440,7 +440,7 @@ class Model(nn.Module):
|
||||
|
||||
def sanitize(self, weights):
|
||||
for k, v in weights.items():
|
||||
if "conv_1d.weight" in k and v.ndim == 3:
|
||||
if "conv_1d.weight" in k and v.shape[-1] != 1:
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
if "lm_head.weight" not in weights:
|
||||
self.pop("lm_head")
|
||||
|
@ -720,7 +720,7 @@ def convert(
|
||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
dtype = mx.float16 if quantize else getattr(mx, dtype)
|
||||
dtype = getattr(mx, dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
|
||||
if quantize and dequantize:
|
||||
|
Loading…
Reference in New Issue
Block a user