Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
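For context, "fast RMS/LN Norm" refers to replacing hand-written normalization code with MLX's built-in nn.RMSNorm / nn.LayerNorm modules, which are backed by the fused mx.fast.rms_norm and mx.fast.layer_norm kernels. The sketch below is illustrative only and is not taken from this commit; the class name SlowRMSNorm and the dimensions are made up. It shows the before/after pattern the commit message describes.

```python
import mlx.core as mx
import mlx.nn as nn


class SlowRMSNorm(nn.Module):
    """The pattern being replaced: an unfused, hand-written RMSNorm."""

    def __init__(self, dims: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        # Normalize by the root mean square over the last axis, then scale.
        rms = mx.rsqrt(mx.square(x).mean(-1, keepdims=True) + self.eps)
        return self.weight * (x * rms)


# After the switch, model files simply use the built-in module ...
norm = nn.RMSNorm(4096, eps=1e-6)

# ... which dispatches to the fused kernel, roughly equivalent to:
# y = mx.fast.rms_norm(x, norm.weight, eps=1e-6)
```

The "use sdpa" item similarly refers to mx.fast.scaled_dot_product_attention, which replaces a hand-written softmax(QK^T / sqrt(d))V attention computation with a single fused call.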
```diff
@@ -1,10 +1,11 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import argparse
 import copy
 
 import mlx.core as mx
 import mlx.nn as nn
 import models
 import utils
 from mlx.utils import tree_flatten
@@ -12,11 +13,8 @@ from mlx.utils import tree_flatten
 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
 
-    # Get model classes
-    model_class, model_args_class = utils._get_classes(config=config)
-
     # Load the model:
-    model = model_class(model_args_class.from_dict(config))
+    model = models.Model(models.ModelArgs.from_dict(config))
     model.load_weights(list(weights.items()))
 
     # Quantize the model:
```
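The hunk is cut off after the "# Quantize the model:" comment. Purely as an illustration, and not the repository's actual code, a quantize helper of this shape typically finishes by quantizing the freshly loaded model and recording the settings in the returned config. The sketch below assumes the nn.quantize API available in recent mlx releases and hypothetical group_size/bits arguments.

```python
import copy

import mlx.nn as nn
from mlx.utils import tree_flatten


def quantize_sketch(model, config, group_size=64, bits=4):
    # Keep the original config intact and work on a copy.
    quantized_config = copy.deepcopy(config)

    # Convert eligible linear layers to quantized layers in place.
    nn.quantize(model, group_size=group_size, bits=bits)

    # Record the settings so the quantized weights can be reloaded later.
    quantized_config["quantization"] = {"group_size": group_size, "bits": bits}

    # Flatten parameters into a {name: array} mapping ready for saving.
    quantized_weights = dict(tree_flatten(model.parameters()))
    return quantized_weights, quantized_config
```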