mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Mixtral: Fix non-default arg follows default exception (#450)
Mixtral models throw the following exception ``` Traceback (most recent call last): File "<frozen runpy>", line 198, in _run_module_as_main File "<frozen runpy>", line 88, in _run_code File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 119, in <module> main(args) File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/generate.py", line 96, in main model, tokenizer = load(args.model, tokenizer_config=tokenizer_config) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 278, in load model = load_model(model_path) ^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 221, in load_model model_class, model_args_class = _get_classes(config=config) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/utils.py", line 46, in _get_classes arch = importlib.import_module(f"mlx_lm.models.{model_type}") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/importlib/__init__.py", line 126, in import_module return _bootstrap._gcd_import(name[level:], package, level) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "<frozen importlib._bootstrap>", line 1204, in _gcd_import File "<frozen importlib._bootstrap>", line 1176, in _find_and_load File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked File "<frozen importlib._bootstrap>", line 690, in _load_unlocked File "<frozen importlib._bootstrap_external>", line 940, in exec_module File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/mlx_lm/models/mixtral.py", line 11, in <module> @dataclass ^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1230, in dataclass return wrap(cls) ^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1220, in wrap return _process_class(cls, init, repr, eq, order, unsafe_hash, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 1027, in _process_class _init_fn(all_init_fields, File "/opt/homebrew/anaconda3/lib/python3.11/dataclasses.py", line 545, in _init_fn raise TypeError(f'non-default argument {f.name!r} ' TypeError: non-default argument 'model_type' follows default argument ```
This commit is contained in:
parent
b05907c87e
commit
cc671cd1c7
@ -10,6 +10,8 @@ from .base import BaseModelArgs
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs(BaseModelArgs):
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
vocab_size: int
|
||||||
vocab_size: int = 32000
|
vocab_size: int = 32000
|
||||||
max_position_embeddings: int = 4096 * 32
|
max_position_embeddings: int = 4096 * 32
|
||||||
hidden_size: int = 4096
|
hidden_size: int = 4096
|
||||||
@ -20,8 +22,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
num_key_value_heads: int = 8
|
num_key_value_heads: int = 8
|
||||||
num_local_experts: int = 8
|
num_local_experts: int = 8
|
||||||
rms_norm_eps: float = 1e-5
|
rms_norm_eps: float = 1e-5
|
||||||
vocab_size: int
|
|
||||||
model_type: str
|
|
||||||
rope_theta: float = 1e6
|
rope_theta: float = 1e6
|
||||||
rope_traditional: bool = False
|
rope_traditional: bool = False
|
||||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||||
|
Loading…
Reference in New Issue
Block a user