fixing generate and logits outputs

Author: Goekdeniz-Guelmez
Date:   2024-09-04 22:46:45 +02:00
Parent: 236acb16a8
Commit: de1fdc7fdf
4 changed files with 28 additions and 32 deletions

File 1 of 4:

@@ -9,17 +9,19 @@ import mlx.nn as nn
 class MambaCache:
-    def __init__(self, batch_size, intermediate_size, ssm_state_size, conv_kernel_size):
-        self.h = mx.zeros((batch_size, intermediate_size, ssm_state_size))
-        self.conv_states = mx.zeros((batch_size, conv_kernel_size - 1, intermediate_size))
+    def __init__(self, num_layers, conv_state_size, ssm_state_size):
+        self.conv_states = [None for _ in range(num_layers)]
+        self.ssm_states = [None for _ in range(num_layers)]
+        self.offset = 0

-    def update(self, new_h, new_conv_state):
-        self.h = new_h
-        self.conv_states = mx.concatenate([self.conv_states[:, 1:, :], new_conv_state], axis=1)
+    def update(self, layer_idx, conv_state, ssm_state):
+        self.conv_states[layer_idx] = conv_state
+        self.ssm_states[layer_idx] = ssm_state
+        self.offset += 1

-    @classmethod
-    def init_cache(cls, batch_size, intermediate_size, ssm_state_size, conv_kernel_size):
-        return cls(batch_size, intermediate_size, ssm_state_size, conv_kernel_size)
+    @property
+    def state(self):
+        return self.conv_states, self.ssm_states


 class KVCache:

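Note: the rewritten MambaCache keeps one (conv_state, ssm_state) slot per layer and exposes both lists through the new state property, which lets the mx.eval([c.state for c in cache]) call in generate_step treat Mamba caches like any other cache. A minimal usage sketch, assuming MambaCache is in scope; the sizes and array shapes are illustrative assumptions, not values from this commit:

    import mlx.core as mx

    # Sizes and shapes below are illustrative assumptions only.
    num_layers, conv_state_size, ssm_state_size = 2, 3, 16
    cache = MambaCache(num_layers, conv_state_size, ssm_state_size)

    # Each layer overwrites its own slot once per decoding step.
    cache.update(
        layer_idx=0,
        conv_state=mx.zeros((1, conv_state_size, 768)),  # assumed shape
        ssm_state=mx.zeros((1, 768, ssm_state_size)),    # assumed shape
    )
    conv_states, ssm_states = cache.state  # per-layer lists; offset is now 1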
File 2 of 4:

@@ -1,5 +1,4 @@
 from dataclasses import dataclass
 from typing import Optional
 import math
@@ -257,7 +256,6 @@ class Mamba(nn.Module):
     def __call__(self, x: mx.array, caches):
         x = self.embeddings(x)
-        print(x.shape)
         for i, layer in enumerate(self.layers):
             x, caches[i] = layer(x, caches[i])
         return x, caches
@@ -292,10 +290,6 @@ class Model(nn.Module):
             logits = self.backbone.embeddings.as_linear(x)
         else:
             logits = self.lm_head(x)
-        print(f"Logits shape: {logits.shape}")
-        # logits : (B, T, vocab_size)
-        print(logits)
         return logits, cache

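Note: with the debug prints removed, Model.__call__ still returns a (logits, cache) tuple rather than bare logits, with logits shaped (B, T, vocab_size). A rough sketch of that call contract, assuming model is a loaded Mamba Model; whether cache=None is accepted for a fresh pass is an assumption, as are the token ids:

    import mlx.core as mx

    tokens = mx.array([[1, 2, 3]])             # (B, T); ids are made up
    logits, cache = model(tokens, cache=None)  # logits: (B, T, vocab_size)
    next_token = mx.argmax(logits[:, -1, :], axis=-1)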
File 3 of 4:

@@ -104,7 +104,6 @@ def linear_to_lora_layers(
"cohere",
"minicpm",
"deepseek",
"mamba"
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]:

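Note: "mamba" is dropped from this branch because the keys it selects, self_attn.q_proj and self_attn.v_proj, are attention projections that Mamba blocks do not have; a Mamba entry would need its own key set, for example (module names hypothetical):

    keys = set(["mixer.in_proj", "mixer.out_proj"])  # hypothetical targets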
File 4 of 4:

@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer

 # Local imports
-from .models.base import KVCache, RotatingKVCache, MambaCache
+from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
@@ -165,7 +165,7 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
-        model: The model to use for generation.
+        model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
           Default: ``0``.
         repetition_penalty (float, optional): The penalty factor for repeating
@@ -236,35 +236,36 @@ def generate_step(
     def _step(y):
         nonlocal repetition_context
-        logits = model(y[None], cache=cache)
+        if model.model_type == "mamba":
+            logits, _ = model(y[None], cache=cache)
+        else:
+            logits = model(y[None], cache=cache)
         logits = logits[:, -1, :]

         if repetition_penalty:
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            next_token, logprobs = sample(logits)
-            repetition_context.append(next_token.item())
+            y, logprobs = sample(logits)
+            repetition_context.append(y.item())
         else:
-            next_token, logprobs = sample(logits)
+            y, logprobs = sample(logits)

         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return next_token, logprobs.squeeze(0)
+        return y, logprobs.squeeze(0)

-    if hasattr(model, 'generate_step'):
-        y, logprobs = model.generate_step(prompt)
-    else:
-        y, logprobs = _step(y)
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
         y = y[prefill_step_size:]

+    y, logprobs = _step(y)
+
     mx.async_eval(y)
     while True:
-        if hasattr(model, 'generate_step'):
-            next_y, next_logprobs = model.generate_step(y)
-        else:
-            next_y, next_logprobs = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
         yield y.item(), logprobs
         y, logprobs = next_y, next_logprobs
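
Note: after this change the decode loop no longer special-cases models that carry a generate_step method; the only remaining branch is on model_type, since Mamba's Model.__call__ returns (logits, cache) while other models return logits directly. A condensed sketch of the resulting flow (decode_loop is a hypothetical wrapper name; model, cache, and sample are assumed in scope as in the real function, and prefill plus repetition-penalty handling are omitted):

    import mlx.core as mx

    def decode_loop(y):
        # Condensed from generate_step; not the full implementation.
        def _step(y):
            if model.model_type == "mamba":
                logits, _ = model(y[None], cache=cache)  # unpack the tuple
            else:
                logits = model(y[None], cache=cache)
            y, logprobs = sample(logits[:, -1, :])  # sample the last position
            return y, logprobs.squeeze(0)

        y, logprobs = _step(y)
        mx.async_eval(y)  # start evaluating before the caller consumes it
        while True:
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y)  # overlap the next step with this yield
            yield y.item(), logprobs
            y, logprobs = next_y, next_logprobs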