mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-13 21:06:38 +08:00
generation works but outputs gibberish
This commit is contained in:
parent
4ab5139c05
commit
ab4cf1d1cf
@ -338,30 +338,3 @@ class MambaCache(_BaseCache):
|
|||||||
@state.setter
|
@state.setter
|
||||||
def state(self, v):
|
def state(self, v):
|
||||||
self.cache = v
|
self.cache = v
|
||||||
|
|
||||||
|
|
||||||
class Mamba2Cache:
|
|
||||||
def __init__(self, num_layers):
|
|
||||||
self.conv_states = [None] * num_layers
|
|
||||||
self.ssm_states = [None] * num_layers
|
|
||||||
self.seqlen_offset = 0
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return (self.conv_states[idx], self.ssm_states[idx])
|
|
||||||
|
|
||||||
def __setitem__(self, idx, value):
|
|
||||||
self.conv_states[idx], self.ssm_states[idx] = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self):
|
|
||||||
return {
|
|
||||||
'conv_states': self.conv_states,
|
|
||||||
'ssm_states': self.ssm_states,
|
|
||||||
'seqlen_offset': self.seqlen_offset
|
|
||||||
}
|
|
||||||
|
|
||||||
@state.setter
|
|
||||||
def state(self, v):
|
|
||||||
self.conv_states = v['conv_states']
|
|
||||||
self.ssm_states = v['ssm_states']
|
|
||||||
self.seqlen_offset = v['seqlen_offset']
|
|
@ -1,16 +1,11 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Tuple, Union, Optional
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import mlx.nn as nn
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs
|
from .base import BaseModelArgs
|
||||||
from .cache import Mamba2Cache
|
from .cache import MambaCache
|
||||||
|
|
||||||
# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs(BaseModelArgs):
|
class ModelArgs(BaseModelArgs):
|
||||||
@ -48,6 +43,7 @@ class ModelArgs(BaseModelArgs):
|
|||||||
if self.time_step_rank == "auto":
|
if self.time_step_rank == "auto":
|
||||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
|
||||||
|
|
||||||
class MambaRMSNormGated(nn.Module):
|
class MambaRMSNormGated(nn.Module):
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -97,175 +93,108 @@ class DepthWiseConv1d(nn.Module):
|
|||||||
return y, x[:, -K + 1 :, :]
|
return y, x[:, -K + 1 :, :]
|
||||||
|
|
||||||
|
|
||||||
class Mamba2Mixer(nn.Module):
|
class Mamba2Block(nn.Module):
|
||||||
def __init__(self, args, layer_idx):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_idx = layer_idx
|
self.args = args
|
||||||
self.hidden_size = args.hidden_size
|
|
||||||
self.intermediate_size = args.intermediate_size
|
self.intermediate_size = args.intermediate_size
|
||||||
self.num_heads = args.num_heads
|
self.time_step_rank = args.time_step_rank
|
||||||
self.head_dim = args.head_dim
|
|
||||||
self.ssm_state_size = args.state_size
|
|
||||||
self.n_groups = args.n_groups
|
|
||||||
self.conv_kernel_size = args.conv_kernel
|
self.conv_kernel_size = args.conv_kernel
|
||||||
self.use_conv_bias = args.use_conv_bias
|
self.hidden_size = args.hidden_size
|
||||||
self.use_bias = args.use_bias
|
self.state_size = args.state_size
|
||||||
self.time_step_min = args.time_step_min
|
self.num_heads = args.num_heads
|
||||||
self.time_step_max = args.time_step_max
|
self.head_dim = args.hidden_size // args.num_heads
|
||||||
self.chunk_size = args.chunk_size
|
self.n_groups = args.n_groups
|
||||||
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
|
|
||||||
|
|
||||||
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
|
self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
|
||||||
|
self.conv1d = DepthWiseConv1d(
|
||||||
|
in_channels=self.conv_dim,
|
||||||
|
out_channels=self.conv_dim,
|
||||||
|
kernel_size=args.conv_kernel,
|
||||||
|
bias=args.use_conv_bias,
|
||||||
|
groups=self.conv_dim,
|
||||||
|
padding=args.conv_kernel - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
projection_size = args.intermediate_size + self.conv_dim + args.num_heads
|
||||||
self.in_proj = nn.Linear(
|
self.in_proj = nn.Linear(
|
||||||
self.hidden_size,
|
args.hidden_size,
|
||||||
projection_size,
|
projection_size,
|
||||||
bias=args.use_bias
|
bias=args.use_bias
|
||||||
)
|
)
|
||||||
self.conv1d = nn.Conv1d(
|
|
||||||
self.conv_dim,
|
|
||||||
self.conv_dim,
|
|
||||||
self.conv_kernel_size,
|
|
||||||
groups=self.conv_dim,
|
|
||||||
bias=self.use_conv_bias
|
|
||||||
)
|
|
||||||
self.act = nn.SiLU()
|
self.act = nn.SiLU()
|
||||||
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
|
|
||||||
self.out_proj = nn.Linear(
|
self.A_log = mx.zeros(args.num_heads)
|
||||||
self.intermediate_size,
|
self.D = mx.ones((args.num_heads,))
|
||||||
self.hidden_size,
|
self.dt_bias = mx.zeros(args.num_heads)
|
||||||
bias=self.use_bias
|
|
||||||
|
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
|
||||||
|
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def ssm_step(self, x, state, dt_proj):
|
||||||
|
A = -mx.exp(self.A_log)
|
||||||
|
D = self.D
|
||||||
|
delta = nn.softplus(dt_proj + self.dt_bias)
|
||||||
|
|
||||||
|
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
|
||||||
|
|
||||||
|
batch_size = B.shape[0]
|
||||||
|
B = B.reshape(batch_size, self.n_groups, self.state_size)
|
||||||
|
C = C.reshape(batch_size, -1, self.state_size)
|
||||||
|
|
||||||
|
delta = delta.reshape(batch_size, self.num_heads, 1)
|
||||||
|
A = A.reshape(1, self.num_heads, 1)
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
new_state = delta * B
|
||||||
|
else:
|
||||||
|
new_state = delta * (B + state * mx.exp(delta * A))
|
||||||
|
|
||||||
|
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
|
||||||
|
y = y + D * x[:, :self.num_heads]
|
||||||
|
return y, new_state
|
||||||
|
|
||||||
|
def __call__(self, x, cache):
|
||||||
|
B, T, D = x.shape
|
||||||
|
if cache is None:
|
||||||
|
cache = [None, None]
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for t in range(T):
|
||||||
|
xt = x[:, t, :]
|
||||||
|
xz = self.in_proj(xt)
|
||||||
|
|
||||||
|
x_t, z_t, dt_proj = mx.split(
|
||||||
|
xz,
|
||||||
|
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
|
||||||
|
axis=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
self.A_log = mx.zeros(self.num_heads)
|
# Use the new DepthWiseConv1d with caching
|
||||||
self.D = mx.ones(self.num_heads)
|
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
||||||
self.dt_bias = mx.zeros(self.num_heads)
|
x_t = conv_out.squeeze(1)
|
||||||
|
x_t = nn.silu(x_t)
|
||||||
|
y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
|
||||||
|
z_t = nn.silu(z_t)
|
||||||
|
|
||||||
def __call__(self, input_states, cache):
|
# Element-wise multiplication
|
||||||
batch_size, seq_len, _ = input_states.shape
|
output_t = y_t[:, :, None] * z_t[:, None, :]
|
||||||
dtype = input_states.dtype
|
|
||||||
|
|
||||||
projected_states = self.in_proj(input_states)
|
# Sum across the second dimension to match the intermediate_size
|
||||||
|
output_t = output_t.sum(axis=1)
|
||||||
|
|
||||||
# Calculate the sizes of each split
|
output_t = self.out_proj(output_t)
|
||||||
total_size = projected_states.shape[-1]
|
outputs.append(output_t)
|
||||||
remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
|
|
||||||
d_mlp = remaining_size // 2
|
|
||||||
sizes = [
|
|
||||||
d_mlp,
|
|
||||||
d_mlp,
|
|
||||||
self.intermediate_size,
|
|
||||||
self.conv_dim,
|
|
||||||
self.num_heads
|
|
||||||
]
|
|
||||||
|
|
||||||
# Perform the split operation
|
output = mx.stack(outputs, axis=1)
|
||||||
split_result = mx.split(projected_states, sizes, axis=-1)
|
return output
|
||||||
|
|
||||||
# Print debug information
|
|
||||||
print(f"Number of split parts: {len(split_result)}")
|
|
||||||
print(f"Shapes of split parts: {[part.shape for part in split_result]}")
|
|
||||||
|
|
||||||
# Flexibly handle the split result
|
|
||||||
_, _, _, gate, hidden_states, dt = split_result
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
conv_state = cache.conv_states[self.layer_idx]
|
|
||||||
if conv_state is None:
|
|
||||||
# Initialize conv_state if it's None
|
|
||||||
conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
|
|
||||||
|
|
||||||
conv_state = mx.roll(conv_state, -1, -2) # Roll along the kernel dimension
|
|
||||||
|
|
||||||
# Reshape hidden_states to match conv_state dimensions
|
|
||||||
hidden_states_reshaped = hidden_states[:, None, None, :]
|
|
||||||
|
|
||||||
conv_state = mx.concat([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
|
|
||||||
cache.conv_states[self.layer_idx] = conv_state
|
|
||||||
|
|
||||||
# Adjust the convolution operation
|
|
||||||
hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
|
|
||||||
|
|
||||||
if self.use_conv_bias:
|
|
||||||
hidden_states += self.conv1d.bias
|
|
||||||
hidden_states = self.act(hidden_states)[:, None, :]
|
|
||||||
else:
|
|
||||||
hidden_states = hidden_states.transpose(0, 2, 1)
|
|
||||||
hidden_states = self.act(self.conv1d(hidden_states)).transpose(0, 2, 1)
|
|
||||||
|
|
||||||
hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
|
|
||||||
|
|
||||||
A = -mx.exp(self.A_log.astype(mx.float32))
|
|
||||||
dt = nn.softplus(dt + self.dt_bias)
|
|
||||||
dt = mx.clip(dt, self.time_step_min, self.time_step_max)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
|
|
||||||
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
|
|
||||||
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
|
|
||||||
|
|
||||||
B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
|
|
||||||
C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
|
|
||||||
|
|
||||||
if cache is not None and cache.seqlen_offset > 0:
|
|
||||||
ssm_state = cache.ssm_states[self.layer_idx]
|
|
||||||
dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
|
|
||||||
dB = dt[:, None, :, None] * B
|
|
||||||
dBx = dB * hidden_states[:, :, :, None]
|
|
||||||
ssm_state = ssm_state * dA + dBx
|
|
||||||
cache.ssm_states[self.layer_idx] = ssm_state
|
|
||||||
|
|
||||||
y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
|
|
||||||
D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
|
|
||||||
y = y + hidden_states * D
|
|
||||||
|
|
||||||
y = y.reshape(batch_size, -1)[:, None, :]
|
|
||||||
else:
|
|
||||||
# Implement chunked computation here (simplified version)
|
|
||||||
pad_size = self.chunk_size - (seq_len % self.chunk_size)
|
|
||||||
hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
|
|
||||||
B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
|
|
||||||
C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
|
|
||||||
|
|
||||||
chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
|
|
||||||
y_list = []
|
|
||||||
ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
|
|
||||||
|
|
||||||
for i in range(chunks):
|
|
||||||
chunk_start = i * self.chunk_size
|
|
||||||
chunk_end = (i + 1) * self.chunk_size
|
|
||||||
chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
|
|
||||||
chunk_B = B_padded[:, chunk_start:chunk_end]
|
|
||||||
chunk_C = C_padded[:, chunk_start:chunk_end]
|
|
||||||
|
|
||||||
chunk_dt = dt[:, chunk_start:chunk_end]
|
|
||||||
dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
|
|
||||||
dB = chunk_dt[:, :, None, None] * chunk_B
|
|
||||||
dBx = dB * chunk_h[:, :, :, None]
|
|
||||||
|
|
||||||
chunk_y = mx.zeros_like(chunk_h)
|
|
||||||
for j in range(self.chunk_size):
|
|
||||||
ssm_state = ssm_state * dA[:, j] + dBx[:, j]
|
|
||||||
chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
|
|
||||||
|
|
||||||
y_list.append(chunk_y)
|
|
||||||
|
|
||||||
y = mx.concat(y_list, axis=1)
|
|
||||||
if pad_size > 0:
|
|
||||||
y = y[:, :seq_len]
|
|
||||||
|
|
||||||
D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
|
|
||||||
y = y + hidden_states * D
|
|
||||||
y = y.reshape(batch_size, seq_len, -1)
|
|
||||||
|
|
||||||
y = self.norm(y, gate)
|
|
||||||
contextualized_states = self.out_proj(y.astype(dtype))
|
|
||||||
|
|
||||||
return contextualized_states
|
|
||||||
|
|
||||||
|
|
||||||
class Mamba2Block(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, args: ModelArgs, layer_idx: int):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mixer = Mamba2Mixer(args, layer_idx)
|
self.mixer = Mamba2Block(args)
|
||||||
self.norm = nn.RMSNorm(args.hidden_size)
|
self.norm = nn.RMSNorm(args.hidden_size)
|
||||||
|
|
||||||
def __call__(self, x: mx.array, cache):
|
def __call__(self, x: mx.array, cache):
|
||||||
@ -277,24 +206,16 @@ class Mamba2(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
|
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
|
||||||
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, x: mx.array, cache):
|
||||||
self,
|
x = self.embeddings(x)
|
||||||
inputs: mx.array,
|
|
||||||
cache=None
|
|
||||||
):
|
|
||||||
hidden_states = self.embeddings(inputs)
|
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = Mamba2Cache(len(self.layers))
|
cache = [None] * len(self.layers)
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
for i, layer in enumerate(self.layers):
|
x = layer(x, c)
|
||||||
hidden_states = layer(hidden_states, cache[i])
|
return self.norm_f(x)
|
||||||
|
|
||||||
hidden_states = self.norm_f(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
@ -302,7 +223,10 @@ class Model(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
self.model_type = args.model_type
|
self.model_type = args.model_type
|
||||||
|
|
||||||
self.backbone = Mamba2(args)
|
self.backbone = Mamba2(args)
|
||||||
|
# self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||||
|
|
||||||
if not args.tie_word_embeddings:
|
if not args.tie_word_embeddings:
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
@ -316,9 +240,6 @@ class Model(nn.Module):
|
|||||||
else:
|
else:
|
||||||
logits = self.lm_head(x)
|
logits = self.lm_head(x)
|
||||||
|
|
||||||
print(logits)
|
|
||||||
print(logits.shape)
|
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
@ -328,7 +249,7 @@ class Model(nn.Module):
|
|||||||
return weights
|
return weights
|
||||||
|
|
||||||
def make_cache(self):
|
def make_cache(self):
|
||||||
return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]
|
return [MambaCache() for _ in range(len(self.layers))]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user