From 1a6688384d1ec34f88a83ad49ca0a6746e5bdaad Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 10 Nov 2024 17:19:00 +0100
Subject: [PATCH] implemented multi-token inputs, but still generating gibberish

---
 llms/mlx_lm/models/mamba2.py | 368 +++++++++++++++--------------------
 1 file changed, 157 insertions(+), 211 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 1cf493dc..8ea641f4 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -91,168 +91,6 @@ def ssd(x, A, B, C, chunk_size):
     return mx.concatenate(outputs, axis=1), state


-# class DepthWiseConv1d(nn.Module):
-#     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
-#         super().__init__()
-#         self.in_channels = in_channels
-#         self.out_channels = out_channels
-#         self.kernel_size = kernel_size
-#         self.padding = padding
-#         self.groups = groups if groups is not None else in_channels
-
-#         assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
-#         assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
-
-#         self.weight = mx.random.normal((in_channels, 1, kernel_size))
-#         self.bias = mx.zeros((out_channels,)) if bias else None
-
-#     def __call__(self, x: mx.array, cache=None) -> mx.array:
-#         B, L, C = x.shape
-#         K = self.kernel_size
-
-#         assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
-
-#         if cache is not None:
-#             if isinstance(cache.conv_states[0], type(None)):
-#                 cache.conv_states[0] = mx.zeros((B, K-1, C))
-
-#             x = mx.concatenate([cache.conv_states[0], x], axis=1)
-
-#         outputs = []
-#         for c in range(C):
-#             # Input prep debug
-#             x_c = x[:, :, c]
-#             x_c = mx.expand_dims(x_c, axis=1)
-
-#             # Weight prep debug
-#             w_c = self.weight[c]
-#             if w_c.ndim == 2:
-#                 w_c = mx.expand_dims(w_c, axis=0)
-#             elif w_c.ndim == 1:
-#                 w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
-
-#             y_c = mx.conv_general(
-#                 x_c,
-#                 w_c,
-#                 stride=1,
-#                 padding=0
-#             )
-#             if self.bias is not None:
-#                 y_c = y_c + self.bias[c]
-
-#             y_c = mx.squeeze(y_c, axis=1)
-#             outputs.append(y_c)
-
-#         # Output statistics
-#         y = mx.stack(outputs, axis=-1)
-
-#         # Cache update debug
-#         if cache is not None:
-#             cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
-
-#         return y
-
-
-# class Mamba2Block(nn.Module):
-#     def __init__(self, args: ModelArgs):
-#         super().__init__()
-#         self.args = args
-
-#         d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
-#         self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
-
-#         conv_dim = args.intermediate_size + 2 * args.state_size
-#         self.conv1d = DepthWiseConv1d(
-#             in_channels=conv_dim,
-#             out_channels=conv_dim,
-#             kernel_size=args.conv_kernel,
-#             groups=conv_dim,
-#             bias=args.use_conv_bias,
-#             padding=args.conv_kernel - 1
-#         )
-
-#         self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
-#         self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
-#         self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
-
-#         self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
-#         self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
-
-#         if args.rescale_prenorm_residual:
-#             layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
-#             self.out_proj.weight = self.out_proj.weight * layer_scale
-
-#     def __call__(self, u: mx.array, cache=None):
-#         # Expect input to be shape [batch_size, 1, dim]
-#         batch_size, seq_len, dimension = u.shape
-#         assert seq_len == 1, "Input should be a single token"
-
-#         # Initialize cache if needed
-#         if cache.conv_states[0] is None:
-#             conv_dim = self.args.intermediate_size + 2 * self.args.state_size
-#             cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
-
-#         if cache.ssm_states[0] is None:
-#             cache.ssm_states[0] = mx.zeros((
-#                 batch_size,
-#                 self.args.num_heads,
-#                 self.args.head_dim,
-#                 self.args.state_size
-#             ))
-
-#         # Project input
-#         zxbcdt = self.in_proj(u)
-
-#         # Split projections
-#         n_heads = self.args.num_heads
-#         z = zxbcdt[:, :, :self.args.intermediate_size]
-#         xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
-#         dt = zxbcdt[:, :, -(n_heads):]
-
-#         # Time steps
-#         dt = mx.reshape(dt, (batch_size, n_heads))
-#         dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
-#         dt = mx.maximum(dt, self.args.time_step_floor)
-
-#         # Convolution
-#         xBC = self.conv1d(xBC, cache=cache)
-#         xBC = silu(xBC)
-
-#         # Split states
-#         x = xBC[:, :, :self.args.intermediate_size]
-#         B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
-#         C = xBC[:, :, -self.args.state_size:]
-
-#         # Reshape for SSM
-#         x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
-#         x = mx.squeeze(x, axis=1)
-#         B = mx.reshape(B, (batch_size, 1, self.args.state_size))
-#         B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
-#         B = mx.expand_dims(B, axis=2)
-#         C = mx.reshape(C, (batch_size, 1, self.args.state_size))
-#         C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
-#         C = mx.expand_dims(C, axis=3)
-
-#         # SSM updates
-#         A = -mx.exp(self.A_log)
-#         dA = mx.exp(dt * mx.expand_dims(A, 0))
-#         dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
-
-#         # Update state
-#         x = mx.expand_dims(x, axis=3)
-#         dBx = mx.matmul(x, B)
-#         cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
-
-#         # Compute output
-#         y = mx.matmul(cache.ssm_states[0], C)
-#         y = mx.squeeze(y, axis=-1)
-#         y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
-#         y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
-#         y = self.norm(y + z)
-
-#         return self.out_proj(y)
-
-
 class DepthWiseConv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
         super().__init__()
@@ -313,6 +151,97 @@ class DepthWiseConv1d(nn.Module):
         return y


+# class Mamba2Block(nn.Module):
+#     def __init__(self, args: ModelArgs):
+#         super().__init__()
+#         self.args = args
+
+#         d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
+#         self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
+
+#         conv_dim = args.intermediate_size + 2 * args.state_size
+#         self.conv1d = DepthWiseConv1d(
+#             in_channels=conv_dim,
+#             out_channels=conv_dim,
+#             kernel_size=args.conv_kernel,
+#             groups=conv_dim,
+#             bias=args.use_conv_bias,
+#             padding=args.conv_kernel - 1
+#         )
+
+#         self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
+#         self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
+#         self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
+
+#         self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
+#         self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
+
+#         if args.rescale_prenorm_residual:
+#             layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
+#             self.out_proj.weight = self.out_proj.weight * layer_scale
+
+#     def __call__(self, u: mx.array, cache=None):
+#         batch_size, seq_len, dimension = u.shape
+#         assert seq_len == 1, "Input should be a single token"
+
+#         # Initialize cache states directly using indices
+#         if cache[0] is None:  # conv state
+#             conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+#             cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
+
+#         if cache[1] is None:  # ssm state
+#             cache[1] = mx.zeros((
+#                 batch_size,
+#                 self.args.num_heads,
+#                 self.args.head_dim,
+#                 self.args.state_size
+#             ))
+
+#         zxbcdt = self.in_proj(u)
+
+#         n_heads = self.args.num_heads
+#         z = zxbcdt[:, :, :self.args.intermediate_size]
+#         xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
+#         dt = zxbcdt[:, :, -(n_heads):]
+
+#         dt = mx.reshape(dt, (batch_size, n_heads))
+#         dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
+#         dt = mx.maximum(dt, self.args.time_step_floor)
+
+#         xBC = self.conv1d(xBC, cache=cache)
+#         xBC = silu(xBC)
+
+#         x = xBC[:, :, :self.args.intermediate_size]
+#         B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
+#         C = xBC[:, :, -self.args.state_size:]
+
+#         x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+#         x = mx.squeeze(x, axis=1)
+#         B = mx.reshape(B, (batch_size, 1, self.args.state_size))
+#         B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
+#         B = mx.expand_dims(B, axis=2)
+#         C = mx.reshape(C, (batch_size, 1, self.args.state_size))
+#         C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
+#         C = mx.expand_dims(C, axis=3)
+
+#         A = -mx.exp(self.A_log)
+#         dA = mx.exp(dt * mx.expand_dims(A, 0))
+#         dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
+
+#         x = mx.expand_dims(x, axis=3)
+#         dBx = mx.matmul(x, B)
+#         # Update ssm state directly using cache[1]
+#         cache[1] = cache[1] * dA + dBx
+
+#         y = mx.matmul(cache[1], C)
+#         y = mx.squeeze(y, axis=-1)
+#         y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
+#         y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
+#         y = self.norm(y + z)
+
+#         return self.out_proj(y)
+
+
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -344,64 +273,81 @@ class Mamba2Block(nn.Module):

     def __call__(self, u: mx.array, cache=None):
         batch_size, seq_len, dimension = u.shape
-        assert seq_len == 1, "Input should be a single token"
-
-        # Initialize cache states directly using indices
-        if cache[0] is None:  # conv state
-            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
-            cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
-
-        if cache[1] is None:  # ssm state
-            cache[1] = mx.zeros((
-                batch_size,
-                self.args.num_heads,
-                self.args.head_dim,
-                self.args.state_size
-            ))
-
-        zxbcdt = self.in_proj(u)
-
-        n_heads = self.args.num_heads
-        z = zxbcdt[:, :, :self.args.intermediate_size]
-        xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
-        dt = zxbcdt[:, :, -(n_heads):]
+        # Process sequence in chunks if needed
+        outputs = []
+        current_cache = cache
+
+        for i in range(seq_len):
+            # Extract current token
+            current_input = u[:, i:i+1, :]
+
+            # Initialize cache states if needed
+            if current_cache[0] is None:  # conv state
+                conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+                current_cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))

-        dt = mx.reshape(dt, (batch_size, n_heads))
-        dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
-        dt = mx.maximum(dt, self.args.time_step_floor)
+            if current_cache[1] is None:  # ssm state
+                current_cache[1] = mx.zeros((
+                    batch_size,
+                    self.args.num_heads,
+                    self.args.head_dim,
+                    self.args.state_size
+                ))

-        xBC = self.conv1d(xBC, cache=cache)
-        xBC = silu(xBC)
+            # Project input
+            zxbcdt = self.in_proj(current_input)
+
+            n_heads = self.args.num_heads
+            z = zxbcdt[:, :, :self.args.intermediate_size]
+            xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
+            dt = zxbcdt[:, :, -(n_heads):]

-        x = xBC[:, :, :self.args.intermediate_size]
-        B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
-        C = xBC[:, :, -self.args.state_size:]
+            # Process time steps
+            dt = mx.reshape(dt, (batch_size, n_heads))
+            dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
+            dt = mx.maximum(dt, self.args.time_step_floor)

-        x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
-        x = mx.squeeze(x, axis=1)
-        B = mx.reshape(B, (batch_size, 1, self.args.state_size))
-        B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
-        B = mx.expand_dims(B, axis=2)
-        C = mx.reshape(C, (batch_size, 1, self.args.state_size))
-        C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
-        C = mx.expand_dims(C, axis=3)
+            # Apply convolution
+            xBC = self.conv1d(xBC, cache=current_cache)
+            xBC = silu(xBC)

-        A = -mx.exp(self.A_log)
-        dA = mx.exp(dt * mx.expand_dims(A, 0))
-        dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
+            # Split states
+            x = xBC[:, :, :self.args.intermediate_size]
+            B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
+            C = xBC[:, :, -self.args.state_size:]

-        x = mx.expand_dims(x, axis=3)
-        dBx = mx.matmul(x, B)
-        # Update ssm state directly using cache[1]
-        cache[1] = cache[1] * dA + dBx
+            # Reshape for SSM
+            x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+            x = mx.squeeze(x, axis=1)
+            B = mx.reshape(B, (batch_size, 1, self.args.state_size))
+            B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
+            B = mx.expand_dims(B, axis=2)
+            C = mx.reshape(C, (batch_size, 1, self.args.state_size))
+            C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
+            C = mx.expand_dims(C, axis=3)

-        y = mx.matmul(cache[1], C)
-        y = mx.squeeze(y, axis=-1)
-        y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
-        y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
-        y = self.norm(y + z)
+            # SSM updates
+            A = -mx.exp(self.A_log)
+            dA = mx.exp(dt * mx.expand_dims(A, 0))
+            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)

-        return self.out_proj(y)
+            # Update state
+            x = mx.expand_dims(x, axis=3)
+            dBx = mx.matmul(x, B)
+            current_cache[1] = current_cache[1] * dA + dBx
+
+            # Compute output
+            y = mx.matmul(current_cache[1], C)
+            y = mx.squeeze(y, axis=-1)
+            y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
+            y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
+            y = self.norm(y + z)
+
+            outputs.append(self.out_proj(y))
+
+        # Concatenate all outputs
+        return mx.concatenate(outputs, axis=1)


 class ResidualBlock(nn.Module):