From dff4e529100db808a251b01dedaa9342287929c2 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Thu, 12 Dec 2024 22:52:00 +0100
Subject: [PATCH] adding the model names in the LORA.md file and removing
 unused functions from mamba2.py

---
 llms/mlx_lm/LORA.md          | 25 +++++++++++++++++++++++++
 llms/mlx_lm/models/mamba2.py | 34 ++++------------------------------
 2 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 15676360..6e4286d1 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -7,12 +7,37 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 - Mistral
 - Llama
 - Phi2
+- Phi3
+- Phi3 Small
+- PhiMOE
+- Phixtral
+- Plamo
 - Mixtral
+- Qwen
 - Qwen2
+- Qwen2 MOE
 - Gemma
+- Gemma2
 - OLMo
+- OLMo2
 - MiniCPM
 - InternLM2
+- Mamba
+- Mamba2
+- EXAONE
+- Hunyuan
+- GPT 2
+- GPT Neo
+- GPT BigCode
+- Deepseek
+- Deepseek2
+- OpenLM
+- StableLM
+- Cohere
+- DBRX
+- Nemotron
+- Recurrent Gemma
+- Starcoder
 
 ## Contents
 
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 0ed62287..c3edf1b7 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -9,6 +9,7 @@ from .cache import MambaCache
 
 @dataclass
 class ModelArgs(BaseModelArgs):
+    model_type: str
     num_heads: int
     head_dim: int
     vocab_size: int
@@ -30,15 +31,16 @@ class ModelArgs(BaseModelArgs):
     rms_norm: bool
     chunk_size: int
     tie_word_embeddings: bool
+    dim: int = None
     intermediate_size: int = None
-    use_cache: bool = True
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
     time_step_rank: Union[int, str] = "auto"
-    model_type: str = "mamba2"
 
     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
             self.intermediate_size = int(self.expand * self.hidden_size)
+        if not hasattr(self, "hidden_size"):
+            self.hidden_size = self.dim
         if not hasattr(self, "head_dim"):
             self.head_dim = self.hidden_size // self.num_heads
         if self.time_step_rank == "auto":
@@ -63,34 +65,6 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
-def ssd(x, A, B, C, chunk_size):
-    batch, seqlen, nheads, dim = x.shape
-
-    B = mx.expand_dims(B, axis=2)
-    C = mx.expand_dims(C, axis=2)
-
-    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
-    outputs = []
-
-    for i in range(0, seqlen, chunk_size):
-        chunk = slice(i, min(i + chunk_size, seqlen))
-        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
-
-        x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
-        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
-        B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
-        dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]
-
-        state = state * mx.expand_dims(dA, axis=-1) + dBx
-
-        C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
-        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
-        y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
-        outputs.append(y)
-
-    return mx.concatenate(outputs, axis=1), state
-
-
 class DepthWiseConv1d(nn.Module):
     def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
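
Note on the `ModelArgs` change above: adding `dim: int = None` together with the `hidden_size` fallback in `__post_init__` appears intended to accept Mamba2 configs that name the model width `dim` instead of `hidden_size`. The following is a minimal sketch of that fallback, not the code from `mamba2.py`: `TinyArgs` is a hypothetical stand-in that uses explicit `None` defaults and `is None` checks in place of the patch's `hasattr` pattern, and it resolves `hidden_size` before the sizes derived from it.

```python
# Minimal, self-contained sketch of the config fallback, assuming a checkpoint
# whose config provides `dim` rather than `hidden_size`. TinyArgs is a
# hypothetical stand-in for ModelArgs, not the class from mamba2.py.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyArgs:
    num_heads: int
    expand: int
    dim: Optional[int] = None
    hidden_size: Optional[int] = None
    intermediate_size: Optional[int] = None
    head_dim: Optional[int] = None

    def __post_init__(self):
        # Accept `dim` as an alias for `hidden_size`.
        if self.hidden_size is None:
            self.hidden_size = self.dim
        # Derive the remaining sizes from the resolved width.
        if self.intermediate_size is None:
            self.intermediate_size = int(self.expand * self.hidden_size)
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_heads


args = TinyArgs(num_heads=8, expand=2, dim=1024)
print(args.hidden_size, args.intermediate_size, args.head_dim)  # 1024 2048 128
```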