Optimizations for mamba1 (#1213)

* added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

* Fused Operations in delta, B, C = ... . Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

* Pre-computing A_log. Before: 83.890 tokens-per-sec, after: 85.848 tokens-per-sec

* Update MambaBlock, Batched Input Processing, Improved Cache Handling, Pre-computed Constants, Cleaner State Management, Explicit Return Values. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

* cleaning up and adding apple copyright to helium modelfile

* update Copyright to this year

* nits + even faster

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Gökdeniz Gülmez 2025-02-03 22:36:08 +01:00 committed by GitHub
parent d9924d08d1
commit 0989c073b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 26 deletions

View File

@ -1,3 +1,5 @@
# Copyright © 2025 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple

View File

@ -1,4 +1,4 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024-2025 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias self.intermediate_size, self.hidden_size, bias=args.use_bias
) )
def ssm_step(self, x, state=None): def ssm_step(self, x, A, state=None):
A = -mx.exp(self.A_log)
D = self.D D = self.D
deltaBC = self.x_proj(x) deltaBC = self.x_proj(x)
delta, B, C = mx.split( delta, B, C = map(
deltaBC, self.mixer_norm if self.use_bcdt_rms else lambda x: x,
indices_or_sections=[ mx.split(
self.time_step_rank, deltaBC,
self.time_step_rank + self.ssm_state_size, [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
], axis=-1,
axis=-1, ),
) )
if self.use_bcdt_rms: if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C)) delta, B, C = map(self.mixer_norm, (delta, B, C))
@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
y = y + D * x y = y + D * x
return y, new_state return y, new_state
def __call__(self, x, cache): def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape B, T, D = x.shape
if cache is None: xz = self.in_proj(x)
cache = [None, None] x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
outputs = [] outputs = []
current_state = state_cache
y = []
for t in range(T): for t in range(T):
xt = x[:, t, :] y_t, current_state = self.ssm_step(x[:, t], A, current_state)
xz = self.in_proj(xt) y.append(y_t)
x_t, z_t = xz.split(indices_or_sections=2, axis=1) y = mx.stack(y, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) z = self.out_proj(nn.silu(z) * y)
x_t = conv_out.squeeze(1) return z, (new_conv_cache, current_state)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1]) def __call__(self, x, cache):
z_t = nn.silu(z_t) if cache is None:
output_t = y_t * z_t conv_cache, state_cache = None, None
output_t = self.out_proj(output_t) else:
outputs.append(output_t) conv_cache, state_cache = cache[0], cache[1]
output = mx.stack(outputs, axis=1)
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output return output

View File

@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2025 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union