From fd3bd6d9aac8b12c2801d7ad8f8a76e6222d66d4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 23:00:25 +0200 Subject: [PATCH] Fixing the cache handling, generating works now trying training --- llms/mlx_lm/models/base.py | 16 ---------------- llms/mlx_lm/models/mamba.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 92d14c5a..1a5cd42b 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -6,22 +6,6 @@ from typing import Any, List, Optional import mlx.core as mx import mlx.nn as nn - - -class MambaCache: - def __init__(self, num_layers, conv_state_size, ssm_state_size): - self.conv_states = [None for _ in range(num_layers)] - self.ssm_states = [None for _ in range(num_layers)] - self.offset = 0 - - def update(self, layer_idx, conv_state, ssm_state): - self.conv_states[layer_idx] = conv_state - self.ssm_states[layer_idx] = ssm_state - self.offset += 1 - - @property - def state(self): - return self.conv_states, self.ssm_states class KVCache: diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 47ee4a81..49c0ea11 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -5,7 +5,7 @@ import math import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, MambaCache +from .base import BaseModelArgs @dataclass