Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-08 10:14:36 +08:00)
fixing generate and logits outputs
@@ -9,17 +9,19 @@ import mlx.nn as nn
 class MambaCache:
-    def __init__(self, batch_size, intermediate_size, ssm_state_size, conv_kernel_size):
-        self.h = mx.zeros((batch_size, intermediate_size, ssm_state_size))
-        self.conv_states = mx.zeros((batch_size, conv_kernel_size - 1, intermediate_size))
+    def __init__(self, num_layers, conv_state_size, ssm_state_size):
+        self.conv_states = [None for _ in range(num_layers)]
+        self.ssm_states = [None for _ in range(num_layers)]
+        self.offset = 0
 
-    def update(self, new_h, new_conv_state):
-        self.h = new_h
-        self.conv_states = mx.concatenate([self.conv_states[:, 1:, :], new_conv_state], axis=1)
+    def update(self, layer_idx, conv_state, ssm_state):
+        self.conv_states[layer_idx] = conv_state
+        self.ssm_states[layer_idx] = ssm_state
+        self.offset += 1
 
-    @classmethod
-    def init_cache(cls, batch_size, intermediate_size, ssm_state_size, conv_kernel_size):
-        return cls(batch_size, intermediate_size, ssm_state_size, conv_kernel_size)
+    @property
+    def state(self):
+        return self.conv_states, self.ssm_states
 
 
 class KVCache:
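The reworked MambaCache mirrors the per-layer interface that generate_step already expects from KVCache: states are stored per layer index and exposed through a state property that can be handed to mx.eval. A minimal usage sketch, assuming the new class above; the sizes are invented purely for illustration:

    import mlx.core as mx

    # Invented sizes, for illustration only.
    num_layers, batch, intermediate, kernel, state_size = 2, 1, 16, 4, 8

    cache = MambaCache(num_layers, conv_state_size=kernel - 1, ssm_state_size=state_size)

    # Each layer stores its own states by index, instead of the cache holding
    # one (batch, intermediate, ssm_state) tensor for the whole model.
    for i in range(num_layers):
        conv = mx.zeros((batch, kernel - 1, intermediate))
        ssm = mx.zeros((batch, intermediate, state_size))
        cache.update(i, conv, ssm)

    conv_states, ssm_states = cache.state
    mx.eval(cache.state)  # same pattern generate_step applies to KVCache objects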
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
 from typing import Optional
 
 import math
@@ -257,7 +256,6 @@ class Mamba(nn.Module):
 
     def __call__(self, x: mx.array, caches):
         x = self.embeddings(x)
-        print(x.shape)
         for i, layer in enumerate(self.layers):
             x, caches[i] = layer(x, caches[i])
         return x, caches
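The forward pass threads one cache slot per layer: each block consumes its own entry and returns the updated one. A self-contained toy of the same pattern (the layer below is a stand-in function, not the repository's Mamba block):

    import mlx.core as mx

    def toy_layer(x, cache):
        # Stand-in for a Mamba block: mixes the input with a running state.
        state = mx.zeros_like(x) if cache is None else cache
        state = 0.9 * state + 0.1 * x
        return x + state, state

    layers = [toy_layer, toy_layer, toy_layer]
    caches = [None] * len(layers)

    x = mx.ones((1, 4, 8))  # (batch, time, hidden); sizes are arbitrary
    for i, layer in enumerate(layers):
        x, caches[i] = layer(x, caches[i])
    # x is the final hidden sequence; caches holds one updated state per layer.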
@@ -292,10 +290,6 @@ class Model(nn.Module):
             logits = self.backbone.embeddings.as_linear(x)
         else:
             logits = self.lm_head(x)
 
-        print(f"Logits shape: {logits.shape}")
-        # logits : (B, T, vocab_size)
-        print(logits)
-
         return logits, cache
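With the debug prints removed, Model.__call__ just returns logits of shape (B, T, vocab_size) together with the cache; slicing out the last time step is left to the caller, as _step does below. A sketch of that consumer-side step with invented sizes:

    import mlx.core as mx

    B, T, vocab_size = 1, 5, 32                 # illustrative sizes only
    logits = mx.zeros((B, T, vocab_size))       # stand-in for what Model.__call__ returns

    last = logits[:, -1, :]                     # (B, vocab_size), as in _step below
    next_token = mx.argmax(last, axis=-1)       # greedy choice when temp == 0
    mx.eval(next_token)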
@@ -104,7 +104,6 @@ def linear_to_lora_layers(
         "cohere",
         "minicpm",
         "deepseek",
-        "mamba"
     ]:
         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
         if model.model_type in ["mixtral", "phimoe"]:
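Dropping "mamba" here means an SSM-only model no longer receives the attention-projection LoRA targets, which its blocks do not contain. A small illustration of the gating idea (a hypothetical helper, not the repository's code; only the model types visible in this hunk are listed):

    ATTENTION_LORA_KEYS = {"self_attn.q_proj", "self_attn.v_proj"}

    def default_lora_keys(model_type: str) -> set:
        # Hypothetical mirror of the branch in linear_to_lora_layers: attention-style
        # architectures get the q/v projections; anything else needs its own targets.
        if model_type in {"cohere", "minicpm", "deepseek"}:
            return set(ATTENTION_LORA_KEYS)
        raise ValueError(f"no default LoRA targets for model type {model_type!r}")

    print(default_lora_keys("deepseek"))  # {'self_attn.q_proj', 'self_attn.v_proj'}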
@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models.base import KVCache, RotatingKVCache, MambaCache
+from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
@@ -165,7 +165,7 @@ def generate_step(
 
     Args:
         prompt (mx.array): The input prompt.
-        model: The model to use for generation.
+        model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
           Default: ``0``.
         repetition_penalty (float, optional): The penalty factor for repeating
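For context, generate_step is a generator that yields one (token, logprobs) pair per step. A hedged usage sketch, assuming the mlx_lm load helper; the model path and max_tokens are placeholders:

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step

    model, tokenizer = load("some/mamba-model")  # placeholder path
    prompt = mx.array(tokenizer.encode("Hello"))

    max_tokens = 16
    for (token, logprobs), _ in zip(generate_step(prompt, model, temp=0.0), range(max_tokens)):
        if token == tokenizer.eos_token_id:
            break
        print(tokenizer.decode([token]), end="", flush=True)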
@@ -236,35 +236,36 @@ def generate_step(
 
     def _step(y):
         nonlocal repetition_context
-        logits = model(y[None], cache=cache)
+        if model.model_type == "mamba":
+            logits, _ = model(y[None], cache=cache)
+        else:
+            logits = model(y[None], cache=cache)
         logits = logits[:, -1, :]
 
         if repetition_penalty:
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            next_token, logprobs = sample(logits)
-            repetition_context.append(next_token.item())
+            y, logprobs = sample(logits)
+            repetition_context.append(y.item())
         else:
-            next_token, logprobs = sample(logits)
+            y, logprobs = sample(logits)
 
         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return next_token, logprobs.squeeze(0)
+        return y, logprobs.squeeze(0)
 
-    if hasattr(model, 'generate_step'):
-        y, logprobs = model.generate_step(prompt)
-    else:
-        y, logprobs = _step(y)
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
         y = y[prefill_step_size:]
 
     y, logprobs = _step(y)
 
     mx.async_eval(y)
     while True:
-        if hasattr(model, 'generate_step'):
-            next_y, next_logprobs = model.generate_step(y)
-        else:
-            next_y, next_logprobs = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
         yield y.item(), logprobs
         y, logprobs = next_y, next_logprobs
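The core fix in _step is the return-type mismatch: the Mamba Model.__call__ above returns a (logits, cache) tuple, while the other architectures return logits alone, so the step function unpacks accordingly. A sketch of the same dispatch with stub models (the stubs are illustrations, not the repository's classes):

    import mlx.core as mx

    class StubMamba:
        model_type = "mamba"
        def __call__(self, y, cache=None):
            return mx.zeros((1, y.shape[1], 32)), cache   # (logits, cache) tuple

    class StubTransformer:
        model_type = "llama"
        def __call__(self, y, cache=None):
            return mx.zeros((1, y.shape[1], 32))          # logits only

    def last_logits(model, y, cache=None):
        # Same branch as _step: tuple-returning Mamba vs. plain logits elsewhere.
        if model.model_type == "mamba":
            logits, _ = model(y[None], cache=cache)
        else:
            logits = model(y[None], cache=cache)
        return logits[:, -1, :]

    y = mx.array([1, 2, 3])
    print(last_logits(StubMamba(), y).shape)        # (1, 32)
    print(last_logits(StubTransformer(), y).shape)  # (1, 32)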