This commit is contained in:
Awni Hannun
2024-01-03 15:01:02 -08:00
parent 99581115a0
commit f0aaab7d91

View File

@@ -10,7 +10,7 @@ import mlx.core as mx
import mlx.nn as nn
import transformers
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_map
from mlx.utils import tree_flatten
from models import Model, ModelArgs
@@ -39,7 +39,6 @@ def quantize(weights, config, args):
# Load the model:
model = Model(ModelArgs.from_dict(config))
weights = tree_map(mx.array, weights)
model.load_weights(list(weights.items()))
# Quantize the model: