Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-30 21:31:14 +08:00
Adding the model names to the LORA.md file and removing unused functions from mamba2.py
This commit is contained in:
parent a883e39f41
commit dff4e52910
LORA.md
@@ -7,12 +7,37 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 - Mistral
 - Llama
 - Phi2
+- Phi3
+- Phi3 Small
+- PhiMOE
+- Phixtral
+- Plamo
 - Mixtral
+- Qwen
 - Qwen2
+- Qwen2 MOE
 - Gemma
+- Gemma2
 - OLMo
+- OLMo2
 - MiniCPM
 - InternLM2
+- Mamba
+- Mamba2
+- EXAONE
+- Hunyuan
+- GPT 2
+- GPT Neo
+- GPT BigCode
+- Deepseek
+- Deepseek2
+- OpenLM
+- StableLM
+- Cohere
+- DBRX
+- Nemotron
+- Recurrent Gemma
+- Starcoder
 
 ## Contents
 
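The families above are the architectures whose linear layers can be wrapped with LoRA adapters. As a refresher on what that wrapping does, here is a minimal sketch of a LoRA linear layer in MLX; `LoRALinear` and its initialization scheme are illustrative, not the repo's implementation, and assume `mlx` is installed.

import math
import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    # Hypothetical adapter layer: the frozen base weight is augmented with a
    # trainable low-rank update so that y = linear(x) + scale * (x A) B.
    def __init__(self, in_dims, out_dims, rank=8, scale=1.0):
        super().__init__()
        self.linear = nn.Linear(in_dims, out_dims, bias=False)
        self.linear.freeze()  # only the low-rank factors are trained
        self.lora_a = mx.random.normal((in_dims, rank)) / math.sqrt(in_dims)
        self.lora_b = mx.zeros((rank, out_dims))  # zero init: no change at start
        self.scale = scale

    def __call__(self, x):
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)

layer = LoRALinear(16, 32)
print(layer(mx.random.normal((2, 16))).shape)  # (2, 32)

Because `lora_b` starts at zero, the adapted layer initially reproduces the frozen base layer exactly; training only moves the small factors.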
mamba2.py
@@ -9,6 +9,7 @@ from .cache import MambaCache
 
 @dataclass
 class ModelArgs(BaseModelArgs):
+    model_type: str
     num_heads: int
     head_dim: int
     vocab_size: int
@@ -30,15 +31,16 @@ class ModelArgs(BaseModelArgs):
     rms_norm: bool
     chunk_size: int
     tie_word_embeddings: bool
+    dim: int = None
     intermediate_size: int = None
-    use_cache: bool = True
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
     time_step_rank: Union[int, str] = "auto"
-    model_type: str = "mamba2"
 
     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
             self.intermediate_size = int(self.expand * self.hidden_size)
+        if not hasattr(self, "hidden_size"):
+            self.hidden_size = self.dim
+        if not hasattr(self, "head_dim"):
             self.head_dim = self.hidden_size // self.num_heads
         if self.time_step_rank == "auto":
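The `__post_init__` fallbacks in this hunk let configs that name the model width `dim` (presumably what some Mamba2 checkpoints use) load through the same dataclass as configs that use `hidden_size`. A self-contained sketch of that resolution chain, with a simplified field set, None-checks standing in for the `hasattr` tests, and the width resolved before the sizes derived from it:

from dataclasses import dataclass

@dataclass
class Args:
    num_heads: int
    expand: int
    dim: int = None
    hidden_size: int = None
    intermediate_size: int = None
    head_dim: int = None

    def __post_init__(self):
        # Resolve the width first, then everything derived from it.
        if self.hidden_size is None:
            self.hidden_size = self.dim
        if self.intermediate_size is None:
            self.intermediate_size = int(self.expand * self.hidden_size)
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_heads

args = Args(num_heads=8, expand=2, dim=1024)
print(args.hidden_size, args.intermediate_size, args.head_dim)  # 1024 2048 128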
@@ -63,34 +65,6 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
-def ssd(x, A, B, C, chunk_size):
-    batch, seqlen, nheads, dim = x.shape
-
-    B = mx.expand_dims(B, axis=2)
-    C = mx.expand_dims(C, axis=2)
-
-    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
-    outputs = []
-
-    for i in range(0, seqlen, chunk_size):
-        chunk = slice(i, min(i + chunk_size, seqlen))
-        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
-
-        x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
-        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
-        B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
-        dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]
-
-        state = state * mx.expand_dims(dA, axis=-1) + dBx
-
-        C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
-        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
-        y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
-        outputs.append(y)
-
-    return mx.concatenate(outputs, axis=1), state
-
-
 class DepthWiseConv1d(nn.Module):
     def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
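The deleted `ssd` helper evaluated the SSD recurrence in fixed-size chunks: the hidden state decays by exp(A), accumulates the outer product of the input with B, and the output reads the state back out through C. Since nothing in the module called it, removing it changes no behavior. For reference, here is a sequential, one-timestep-at-a-time sketch of the same recurrence; the function name and shape conventions are illustrative, not the repo's:

import mlx.core as mx

def ssd_reference(x, A, B, C):
    # state_t = exp(A_t) * state_{t-1} + outer(x_t, B_t);  y_t = state_t @ C_t
    # Assumed shapes: x: [batch, seqlen, nheads, dim], A: [seqlen, nheads],
    # B and C: [batch, seqlen, state_size]
    batch, seqlen, nheads, dim = x.shape
    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
    ys = []
    for t in range(seqlen):
        decay = mx.reshape(mx.exp(A[t]), (1, nheads, 1, 1))
        xt = mx.expand_dims(x[:, t], -1)             # [b, h, d, 1]
        Bt = mx.reshape(B[:, t], (batch, 1, 1, -1))  # [b, 1, 1, s]
        state = decay * state + xt * Bt              # decayed outer-product update
        Ct = mx.reshape(C[:, t], (batch, 1, -1, 1))  # [b, 1, s, 1]
        ys.append(mx.matmul(state, Ct)[..., 0])      # read-out: [b, h, d]
    return mx.stack(ys, axis=1), state               # y: [b, seqlen, h, d]

x = mx.random.normal((1, 6, 2, 4))
A = -mx.random.uniform(shape=(6, 2))  # negative log-decay keeps the state bounded
B = mx.random.normal((1, 6, 3))
C = mx.random.normal((1, 6, 3))
y, final_state = ssd_reference(x, A, B, C)
print(y.shape, final_state.shape)  # (1, 6, 2, 4) (1, 2, 4, 3)

Chunking, as the removed function did, only changes how many timesteps are folded into each matmul; the recurrence itself is the same.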