Lazy import + refactor Lora layer addition (#426)

* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
Author: Awni Hannun
Date: 2024-02-12 10:51:02 -08:00
Committed by: GitHub
Parent: 4576946151
Commit: d4666615bb
15 changed files with 127 additions and 393 deletions
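The "lazy model import in mlx_lm" bullet above refers to importing an architecture's module only when a checkpoint actually names it, rather than importing every model file up front. A minimal sketch of that pattern, assuming an importlib lookup keyed on model_type; the module path and function name here are illustrative, not necessarily the exact mlx_lm API:

import importlib
from typing import Tuple, Type


def _get_classes(model_type: str) -> Tuple[Type, Type]:
    # Import the architecture module only when this model_type is requested,
    # so unused architectures never pay their import cost.
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError as e:
        raise ValueError(f"Model type {model_type!r} is not supported.") from e
    return arch.Model, arch.ModelArgs

A loader would call this with the model_type read from the checkpoint's config, which is why the diff below makes that field required.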


@@ -9,6 +9,7 @@ from .base import BaseModelArgs
 @dataclass
 class ModelArgs(BaseModelArgs):
+    model_type: str
     hidden_size: int
     num_hidden_layers: int
     intermediate_size: int
@@ -18,7 +19,6 @@ class ModelArgs(BaseModelArgs):
     num_key_value_heads: int = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
@@ -190,6 +190,7 @@ class LlamaModel(nn.Module):
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.model_type = args.model_type
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
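This hunk promotes model_type from an optional field defaulting to None to a required ModelArgs field and exposes it as model.model_type, so code outside the model file can dispatch on the architecture. A hedged sketch of how LoRA layer addition could use that tag; the target table and the linear_to_lora callable are illustrative stand-ins, not the actual mlx_lm LoRA helpers:

# Illustrative only: target projection names and linear_to_lora are assumptions.
LORA_TARGETS = {
    "llama": ("self_attn.q_proj", "self_attn.v_proj"),
    "phi": ("mixer.Wqkv",),
}


def apply_lora(model, lora_layers: int, linear_to_lora):
    # Pick which projections to wrap based on the architecture tag that the
    # refactor now exposes on every Model instance.
    try:
        targets = LORA_TARGETS[model.model_type]
    except KeyError:
        raise ValueError(f"LoRA not configured for model type {model.model_type!r}")

    # Wrap the chosen projections in the last `lora_layers` transformer blocks.
    for block in model.model.layers[-lora_layers:]:
        for path in targets:
            *prefix, attr = path.split(".")
            parent = block
            for name in prefix:
                parent = getattr(parent, name)
            setattr(parent, attr, linear_to_lora(getattr(parent, attr)))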