Fixing the cache handling, generating works now trying training

This commit is contained in:
Goekdeniz-Guelmez
2024-09-04 23:00:25 +02:00
parent 107575133e
commit fd3bd6d9aa
2 changed files with 1 additions and 17 deletions

View File

@@ -6,22 +6,6 @@ from typing import Any, List, Optional
import mlx.core as mx
import mlx.nn as nn
class MambaCache:
def __init__(self, num_layers, conv_state_size, ssm_state_size):
self.conv_states = [None for _ in range(num_layers)]
self.ssm_states = [None for _ in range(num_layers)]
self.offset = 0
def update(self, layer_idx, conv_state, ssm_state):
self.conv_states[layer_idx] = conv_state
self.ssm_states[layer_idx] = ssm_state
self.offset += 1
@property
def state(self):
return self.conv_states, self.ssm_states
class KVCache:

View File

@@ -5,7 +5,7 @@ import math
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, MambaCache
from .base import BaseModelArgs
@dataclass