Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 17:31:18 +08:00
Optimizations for mamba1 (#1213)
* Added mx.einsum() operations (sketched after this message): before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec
* Fused operations in delta, B, C = ...: before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec
* Pre-computed A_log: before: 83.890 tokens-per-sec, after: 85.848 tokens-per-sec
* Updated MambaBlock (batched input processing, improved cache handling, pre-computed constants, cleaner state management, explicit return values): before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec
* Cleaned up and added the Apple copyright to the Helium model file
* Updated the copyright year
* Nits + even faster

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
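The mx.einsum() rewrite mentioned in the first bullet is not visible in the hunks below. The following is only an illustrative sketch of that kind of change, with hypothetical shapes and helper names (Bmat, ssm_step_broadcast, ssm_step_einsum are not from this commit): a selective-SSM step written with expand_dims/broadcasting versus the same update written with mx.einsum().

import mlx.core as mx

# Hypothetical shapes: delta, x: (batch, D); Bmat, C: (batch, N);
# A: (D, N); state: (batch, D, N)

def ssm_step_broadcast(delta, x, Bmat, C, A, state=None):
    # expand_dims / broadcast formulation of one selective-SSM step
    new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(Bmat, 1)
    if state is not None:
        new_state = new_state + state * mx.exp(mx.expand_dims(delta, -1) * A)
    y = (new_state @ mx.expand_dims(C, -1)).squeeze(-1)
    return y, new_state

def ssm_step_einsum(delta, x, Bmat, C, A, state=None):
    # the same update written with mx.einsum
    new_state = mx.einsum("bd,bn->bdn", delta * x, Bmat)
    if state is not None:
        new_state = new_state + state * mx.exp(mx.expand_dims(delta, -1) * A)
    y = mx.einsum("bdn,bn->bd", new_state, C)
    return y, new_state

b, D, N = 2, 8, 4
delta = mx.random.uniform(shape=(b, D))
x = mx.random.normal((b, D))
Bmat = mx.random.normal((b, N))
C = mx.random.normal((b, N))
A = -mx.exp(mx.zeros((D, N)))
y1, _ = ssm_step_broadcast(delta, x, Bmat, C, A)
y2, _ = ssm_step_einsum(delta, x, Bmat, C, A)
print(mx.allclose(y1, y2))  # True: both formulations compute the same update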
parent d9924d08d1
commit 0989c073b0
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
-        )
-        if self.use_bcdt_rms:
-            delta, B, C = map(self.mixer_norm, (delta, B, C))
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
+        )
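For reference, the fused split above can be exercised in isolation. This is a minimal standalone sketch with placeholder sizes, and a weight-free RMS-style norm standing in for self.mixer_norm:

import mlx.core as mx

# Placeholder sizes; norm is only a stand-in for the model's mixer_norm.
time_step_rank, ssm_state_size = 4, 8
use_bcdt_rms = False
norm = lambda t: t * mx.rsqrt(mx.mean(t * t, axis=-1, keepdims=True) + 1e-5)

deltaBC = mx.random.normal((2, time_step_rank + 2 * ssm_state_size))

# Fused form from the hunk above: split once, then map either the norm or an
# identity over the three pieces instead of a separate `if use_bcdt_rms` branch.
delta, B, C = map(
    norm if use_bcdt_rms else lambda t: t,
    mx.split(
        deltaBC,
        [time_step_rank, time_step_rank + ssm_state_size],
        axis=-1,
    ),
)
print(delta.shape, B.shape, C.shape)  # (2, 4) (2, 8) (2, 8)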
@@ -145,25 +144,40 @@
         y = y + D * x
         return y, new_state
 
-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
 
-        outputs = []
+        current_state = state_cache
+        y = []
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output
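The hunk above batches in_proj and conv1d over the whole sequence, precomputes A once per call, and keeps only the SSM recurrence in the per-token loop. Below is a stripped-down, self-contained sketch of that pattern only, not the real MambaBlock: placeholder layer sizes, the conv1d stage omitted, and a simplified recurrence.

import mlx.core as mx
import mlx.nn as nn

# Placeholder sizes; b_proj stands in for the delta/B/C projections.
D, N, T = 8, 4, 6                       # hidden size, state size, sequence length
in_proj = nn.Linear(D, 2 * D)
b_proj = nn.Linear(D, N)
out_proj = nn.Linear(D, D)
A_log = mx.zeros((D, N))

def process_sequence(x, state=None):
    xz = in_proj(x)                                   # one batched projection for all tokens
    x, z = xz.split(indices_or_sections=2, axis=-1)
    A = -mx.exp(A_log)                                # precomputed once per call
    ys = []
    for t in range(x.shape[1]):                       # only the recurrence stays per-token
        xt = x[:, t]
        update = mx.expand_dims(xt, -1) * mx.expand_dims(b_proj(xt), 1)
        state = update if state is None else state * mx.exp(A) + update
        ys.append((state @ mx.ones((N, 1))).squeeze(-1))
    y = mx.stack(ys, axis=1)
    return out_proj(nn.silu(z) * y), state            # single gated output projection

out, state = process_sequence(mx.random.normal((1, T, D)))
print(out.shape)  # (1, 6, 8)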
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union