Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
This commit is contained in:
Awni Hannun
2024-03-23 07:13:51 -07:00
committed by GitHub
parent fbed720d6f
commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import glob
import json
@@ -8,40 +8,10 @@ from typing import Generator
import mlx.core as mx
import mlx.nn as nn
import models.llama as llama
import models.mixtral as mixtral
import models.phi2 as phi2
import models
import transformers
from huggingface_hub import snapshot_download
# Constants
MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"phi": phi2,
"mixtral": mixtral,
}
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
@@ -157,9 +127,8 @@ def load(path_or_hf_repo: str):
for wf in weight_files:
weights.update(mx.load(wf).items())
model_class, model_args_class = _get_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
model_args = models.ModelArgs.from_dict(config)
model = models.Model(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(
model,