mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 21:01:32 +08:00
Fixing the cache handling, generating works now trying training
This commit is contained in:
@@ -6,22 +6,6 @@ from typing import Any, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class MambaCache:
|
||||
def __init__(self, num_layers, conv_state_size, ssm_state_size):
|
||||
self.conv_states = [None for _ in range(num_layers)]
|
||||
self.ssm_states = [None for _ in range(num_layers)]
|
||||
self.offset = 0
|
||||
|
||||
def update(self, layer_idx, conv_state, ssm_state):
|
||||
self.conv_states[layer_idx] = conv_state
|
||||
self.ssm_states[layer_idx] = ssm_state
|
||||
self.offset += 1
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.conv_states, self.ssm_states
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
@@ -5,7 +5,7 @@ import math
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, MambaCache
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
|
Reference in New Issue
Block a user