Mirror of https://github.com/ml-explore/mlx-examples.git
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
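In MLX terms, these bullets amount to dropping hand-written normalization and attention math in favor of the fused ops (`nn.RMSNorm`, `mx.fast.layer_norm`, `mx.fast.scaled_dot_product_attention`). A minimal before/after sketch of the norm switch; the class and parameter names here are illustrative, not taken from this diff:

```python
import mlx.core as mx
import mlx.nn as nn


class SlowRMSNorm(nn.Module):
    """Hand-rolled RMSNorm of the kind this commit replaces."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        # Normalize by the root-mean-square over the feature axis.
        return self.weight * x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)


# After the switch: the built-in module, backed by a fused kernel in recent MLX
# (hence the "bump mlx versions" bullets above).
norm = nn.RMSNorm(4096, eps=1e-5)

# Attention follows the same pattern: rather than materializing
# softmax(q @ k_T * scale) @ v step by step, call the fused kernel:
# out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
```

The fused kernels skip the intermediate arrays the hand-written versions allocate, which is the point of the change.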
```diff
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import glob
 import json
@@ -8,40 +8,10 @@ from typing import Generator
 
 import mlx.core as mx
 import mlx.nn as nn
-import models.llama as llama
-import models.mixtral as mixtral
-import models.phi2 as phi2
+import models
 import transformers
 from huggingface_hub import snapshot_download
 
-# Constants
-MODEL_MAPPING = {
-    "llama": llama,
-    "mistral": llama,  # mistral is compatible with llama
-    "phi": phi2,
-    "mixtral": mixtral,
-}
-
-
-def _get_classes(config: dict):
-    """
-    Retrieve the model and model args classes based on the configuration.
-
-    Args:
-        config (dict): The model configuration.
-
-    Returns:
-        A tuple containing the Model class and the ModelArgs class.
-    """
-    model_type = config["model_type"]
-    if model_type not in MODEL_MAPPING:
-        msg = f"Model type {model_type} not supported."
-        logging.error(msg)
-        raise ValueError(msg)
-
-    arch = MODEL_MAPPING[model_type]
-    return arch.Model, arch.ModelArgs
-
 
 def fetch_from_hub(hf_path: str):
     model_path = snapshot_download(
@@ -157,9 +127,8 @@ def load(path_or_hf_repo: str):
     for wf in weight_files:
         weights.update(mx.load(wf).items())
 
-    model_class, model_args_class = _get_classes(config=config)
-    model_args = model_args_class.from_dict(config)
-    model = model_class(model_args)
+    model_args = models.ModelArgs.from_dict(config)
+    model = models.Model(model_args)
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(
             model,
```
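For orientation, a hedged sketch of the consolidated load path after this change. Only the `models.ModelArgs` / `models.Model` construction and the `nn.QuantizedLinear.quantize_module` call appear in the hunks above; the config and weight loading around them, the `load_sketch` name, and the exact `quantization` keys are assumptions:

```python
import glob
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn

import models  # the repo's single models module, replacing MODEL_MAPPING


def load_sketch(model_path: str):
    # Assumption: config.json and *.safetensors shards sit under model_path.
    with open(Path(model_path) / "config.json") as f:
        config = json.load(f)
    quantization = config.get("quantization", None)

    weights = {}
    for wf in glob.glob(str(Path(model_path) / "*.safetensors")):
        weights.update(mx.load(wf).items())

    # A single models module provides Model and ModelArgs directly, so the
    # per-architecture MODEL_MAPPING and _get_classes() lookup are gone.
    model_args = models.ModelArgs.from_dict(config)
    model = models.Model(model_args)
    if quantization is not None:
        # Assumption: quantization holds kwargs such as group_size and bits.
        nn.QuantizedLinear.quantize_module(model, **quantization)

    model.load_weights(list(weights.items()))
    return model
```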