Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
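
The bullets above amount to one pattern repeated across the touched examples: hand-written RMSNorm/LayerNorm modules and manual attention math are swapped for MLX's built-in, fused equivalents. Below is a minimal sketch of that substitution, assuming a typical hand-rolled RMSNorm as the starting point (the exact modules differ from example to example, and the class name here is hypothetical):

import mlx.core as mx
import mlx.nn as nn


# Before: the kind of hand-written RMSNorm many examples carried.
class CustomRMSNorm(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return self.weight * x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)


# After: the built-in module and the fused functional kernels.
norm = nn.RMSNorm(4096, eps=1e-5)               # drop-in replacement for CustomRMSNorm
# y = mx.fast.rms_norm(x, weight, eps)          # functional "fast" RMS norm
# y = mx.fast.layer_norm(x, weight, bias, eps)  # functional "fast" layer norm


def attention(q, k, v, scale, mask=None):
    # Replaces manual softmax(q @ k^T * scale) @ v with the fused sdpa kernel.
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)

The mx.fast kernels fuse the reduction and scaling into a single operation, which is the speedup the commit title refers to.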
Author:    Awni Hannun
Committer: GitHub
Date:      2024-03-23 07:13:51 -07:00
Parent:    fbed720d6f
Commit:    b8a348c1b8

44 changed files with 144 additions and 1155 deletions


@@ -1,10 +1,11 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import argparse
 import copy
 
 import mlx.core as mx
 import mlx.nn as nn
+import models
 import utils
 from mlx.utils import tree_flatten
 
@@ -12,11 +13,8 @@ from mlx.utils import tree_flatten
 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
-    # Get model classes
-    model_class, model_args_class = utils._get_classes(config=config)
-
     # Load the model:
-    model = model_class(model_args_class.from_dict(config))
+    model = models.Model(models.ModelArgs.from_dict(config))
     model.load_weights(list(weights.items()))
 
     # Quantize the model:
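
The second hunk is truncated at the quantization step, but the visible change is self-contained: instead of resolving model classes through utils._get_classes, the converter instantiates the example's own models.Model straight from the config. As a rough reconstruction of how that loading path sits inside the quantize helper, assuming the function continues with an nn.quantize-style conversion and that args carries q_group_size/q_bits (none of that continuation is shown in the diff):

import copy

import mlx.nn as nn
import models  # the example-local module referenced in the diff
from mlx.utils import tree_flatten


def quantize(weights, config, args):
    quantized_config = copy.deepcopy(config)

    # Load the model directly from the example's models module
    model = models.Model(models.ModelArgs.from_dict(config))
    model.load_weights(list(weights.items()))

    # Quantize the model (assumed continuation; the file's actual call may differ)
    nn.quantize(model, group_size=args.q_group_size, bits=args.q_bits)

    # Record the quantization settings so loading code can rebuild the layers
    quantized_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    return dict(tree_flatten(model.parameters())), quantized_config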