mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
cleaning up and adding apple copyright to helium modelfile
This commit is contained in:
parent
7b29cf0eda
commit
0d4f2c4dc0
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
|
@ -148,20 +148,15 @@ class MambaBlock(nn.Module):
|
|||||||
return y, new_state
|
return y, new_state
|
||||||
|
|
||||||
def _process_sequence(self, x, conv_cache, state_cache):
|
def _process_sequence(self, x, conv_cache, state_cache):
|
||||||
"""Process a sequence of inputs with cached states"""
|
|
||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
# Project all tokens at once
|
|
||||||
xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
|
xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
|
||||||
x_t, z_t = xz.split(indices_or_sections=2, axis=-1) # Fixed: using split instead of chunk
|
x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
|
||||||
|
|
||||||
# Handle convolution with cache
|
|
||||||
conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
|
conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
|
||||||
x_t = nn.silu(conv_out)
|
x_t = nn.silu(conv_out)
|
||||||
|
|
||||||
# Pre-compute A matrix
|
|
||||||
A = -mx.exp(self.A_log)
|
A = -mx.exp(self.A_log)
|
||||||
|
|
||||||
# Process sequence with state
|
|
||||||
outputs = []
|
outputs = []
|
||||||
current_state = state_cache
|
current_state = state_cache
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
@ -174,17 +169,14 @@ class MambaBlock(nn.Module):
|
|||||||
|
|
||||||
def __call__(self, x, cache):
|
def __call__(self, x, cache):
|
||||||
if cache is None or isinstance(cache, list):
|
if cache is None or isinstance(cache, list):
|
||||||
# Handle legacy cache format
|
|
||||||
conv_cache, state_cache = cache if cache is not None else (None, None)
|
conv_cache, state_cache = cache if cache is not None else (None, None)
|
||||||
else:
|
else:
|
||||||
# Handle MambaCache object
|
|
||||||
conv_cache, state_cache = cache.state
|
conv_cache, state_cache = cache.state
|
||||||
|
|
||||||
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
||||||
x, conv_cache, state_cache
|
x, conv_cache, state_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update cache
|
|
||||||
if isinstance(cache, MambaCache):
|
if isinstance(cache, MambaCache):
|
||||||
cache[0] = new_conv_cache
|
cache[0] = new_conv_cache
|
||||||
cache[1] = new_state_cache
|
cache[1] = new_state_cache
|
||||||
|
Loading…
Reference in New Issue
Block a user