From cc671cd1c7334910b7c7fe44ea5970c321433d0d Mon Sep 17 00:00:00 2001
From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com>
Date: Sun, 18 Feb 2024 13:30:26 -0800
Subject: [PATCH] Mixtral: Fix non-default arg follows default exception (#450)

Mixtral models throw the following exception
```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module>
    main(args)
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main
    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load
    model = load_model(model_path)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model
    model_class, model_args_class = _get_classes(config=config)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module>
    @dataclass
     ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass
    return wrap(cls)
           ^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class
    _init_fn(all_init_fields,
  File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn
    raise TypeError(f'non-default argument {f.name!r} '
TypeError: non-default argument 'model_type' follows default argument
```
---
 llms/mlx_lm/models/mixtral.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 5b4875eb..a2ff0d06 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -10,6 +10,8 @@ from .base import BaseModelArgs
 
 @dataclass
 class ModelArgs(BaseModelArgs):
+    model_type: str
+    vocab_size: int
     vocab_size: int = 32000
     max_position_embeddings: int = 4096 * 32
     hidden_size: int = 4096
@@ -20,8 +22,6 @@ class ModelArgs(BaseModelArgs):
     num_key_value_heads: int = 8
     num_local_experts: int = 8
     rms_norm_eps: float = 1e-5
-    vocab_size: int
-    model_type: str
     rope_theta: float = 1e6
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None