cleaning up and adding apple copyright to helium modelfile

This commit is contained in:
Goekdeniz-Guelmez 2025-01-28 21:02:50 +01:00
parent 7b29cf0eda
commit 0d4f2c4dc0
2 changed files with 3 additions and 9 deletions

View File

@ -1,3 +1,5 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple

View File

@ -148,20 +148,15 @@ class MambaBlock(nn.Module):
return y, new_state return y, new_state
def _process_sequence(self, x, conv_cache, state_cache): def _process_sequence(self, x, conv_cache, state_cache):
"""Process a sequence of inputs with cached states"""
B, T, D = x.shape B, T, D = x.shape
# Project all tokens at once
xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1) xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
x_t, z_t = xz.split(indices_or_sections=2, axis=-1) # Fixed: using split instead of chunk x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
# Handle convolution with cache
conv_out, new_conv_cache = self.conv1d(x_t, conv_cache) conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
x_t = nn.silu(conv_out) x_t = nn.silu(conv_out)
# Pre-compute A matrix
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
# Process sequence with state
outputs = [] outputs = []
current_state = state_cache current_state = state_cache
for t in range(T): for t in range(T):
@ -174,17 +169,14 @@ class MambaBlock(nn.Module):
def __call__(self, x, cache): def __call__(self, x, cache):
if cache is None or isinstance(cache, list): if cache is None or isinstance(cache, list):
# Handle legacy cache format
conv_cache, state_cache = cache if cache is not None else (None, None) conv_cache, state_cache = cache if cache is not None else (None, None)
else: else:
# Handle MambaCache object
conv_cache, state_cache = cache.state conv_cache, state_cache = cache.state
output, (new_conv_cache, new_state_cache) = self._process_sequence( output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache x, conv_cache, state_cache
) )
# Update cache
if isinstance(cache, MambaCache): if isinstance(cache, MambaCache):
cache[0] = new_conv_cache cache[0] = new_conv_cache
cache[1] = new_state_cache cache[1] = new_state_cache