mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 12:51:12 +08:00
adding the model names in the LORA.md file and removing unused functions from mamba2.py
This commit is contained in:
parent a883e39f41
commit dff4e52910
LORA.md
@@ -7,12 +7,37 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Mistral
- Llama
- Phi2
- Phi3
- Phi3 Small
- PhiMOE
- Phixtral
- Plamo
- Mixtral
- Qwen
- Qwen2
- Qwen2 MOE
- Gemma
- Gemma2
- OLMo
- OLMo2
- MiniCPM
- InternLM2
- Mamba
- Mamba2
- EXAONE
- Hunyuan
- GPT 2
- GPT Neo
- GPT BigCode
- Deepseek
- Deepseek2
- OpenLM
- StableLM
- Cohere
- DBRX
- Nemotron
- Recurrent Gemma
- Starcoder
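
Any of the families above can be loaded through `mlx_lm`, optionally together with trained LoRA adapters. A minimal sketch (not part of this diff; the repo id, the adapter directory, and the exact `generate` keywords are assumptions based on typical `mlx_lm` usage):

    # Hedged example: load a supported model plus LoRA adapters and generate.
    # The model repo and adapter path below are placeholders.
    from mlx_lm import load, generate

    model, tokenizer = load(
        "mlx-community/Mistral-7B-Instruct-v0.3-4bit",  # any family listed above
        adapter_path="adapters",  # directory produced by LoRA fine-tuning
    )
    print(generate(model, tokenizer, prompt="Hello", max_tokens=32))
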
## Contents
mamba2.py
@@ -9,6 +9,7 @@ from .cache import MambaCache

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    num_heads: int
    head_dim: int
    vocab_size: int
@@ -30,15 +31,16 @@ class ModelArgs(BaseModelArgs):
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    dim: int = None
    intermediate_size: int = None
    use_cache: bool = True
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    model_type: str = "mamba2"

    def __post_init__(self):
        if not hasattr(self, "intermediate_size"):
            self.intermediate_size = int(self.expand * self.hidden_size)
        if not hasattr(self, "hidden_size"):
            self.hidden_size = self.dim
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
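
For intuition (an illustration, not code from this diff): `__post_init__` is where the dataclass fills in values a checkpoint config may omit, such as deriving the per-head dimension from the hidden size. A self-contained toy version of the same pattern, with made-up names:

    from dataclasses import dataclass

    # Toy dataclass mirroring the derive-missing-fields pattern of ModelArgs.
    @dataclass
    class ToyArgs:
        hidden_size: int
        num_heads: int
        head_dim: int = None

        def __post_init__(self):
            # Derive head_dim when the config does not provide it.
            if self.head_dim is None:
                self.head_dim = self.hidden_size // self.num_heads

    print(ToyArgs(hidden_size=768, num_heads=12).head_dim)  # 64
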
@@ -63,34 +65,6 @@ def silu(x):
    return x * mx.sigmoid(x)


def ssd(x, A, B, C, chunk_size):
    batch, seqlen, nheads, dim = x.shape

    B = mx.expand_dims(B, axis=2)
    C = mx.expand_dims(C, axis=2)

    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
    outputs = []

    for i in range(0, seqlen, chunk_size):
        chunk = slice(i, min(i + chunk_size, seqlen))
        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))

        x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
        B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
        dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]

        state = state * mx.expand_dims(dA, axis=-1) + dBx

        C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
        y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
        outputs.append(y)

    return mx.concatenate(outputs, axis=1), state
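
# Aside (not part of the diff): ssd above is a chunked form of the per-step
# SSM recurrence  h_t = dA_t * h_{t-1} + outer(x_t, B_t),  y_t = h_t @ C_t.
# A hedged single-step reference: this helper does not exist in the repo, it
# reuses this module's `mx` import, and the shapes are assumptions taken
# from the comments in ssd.
def ssm_step(state, x_t, dA_t, B_t, C_t):
    # state: [batch, nheads, dim, state_size]
    # x_t:   [batch, nheads, dim]
    # dA_t:  scalar decay for this step
    # B_t, C_t: [batch, state_size]
    B_t = mx.expand_dims(mx.expand_dims(B_t, 1), 2)   # [batch, 1, 1, state_size]
    C_t = mx.expand_dims(mx.expand_dims(C_t, 1), -1)  # [batch, 1, state_size, 1]
    # Decay the running state, then add the outer product of x_t and B_t.
    state = state * dA_t + mx.expand_dims(x_t, -1) * B_t
    # Read the state out through C_t.
    y_t = mx.matmul(state, C_t)[..., 0]               # [batch, nheads, dim]
    return y_t, state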

class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()