diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 1988abd4..428431e3 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -14,7 +14,6 @@ class ModelArgs(BaseModelArgs):
     hidden_size: int
     attention_bias: bool
     conv1d_width: int
-    embeddings_scale_by_sqrt_dim: bool
     hidden_size: int
     intermediate_size: int
     logits_soft_cap: float
@@ -25,7 +24,14 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float
     attention_window_size: int
     vocab_size: int
-    _block_types: List[str]
+    embeddings_scale_by_sqrt_dim: bool = True
+    block_types: Optional[List[str]] = None
+    _block_types: Optional[List[str]] = None
+
+    def __post_init__(self):
+        # For some reason these have different names in 2B and 9B
+        if self.block_types is None:
+            self.block_types = self._block_types
 
 
 def create_window_causal_mask(N: int, window_size: int):
@@ -202,6 +208,8 @@ class RGLRU(nn.Module):
 
         # Apply gamma normalization to the input.
         multiplier = mx.sqrt(1 - a_square)
+        if cache is None:
+            multiplier[:, 0, :] = 1.0
         normalized_x = gated_x * multiplier.astype(x.dtype)
 
         y, last_h = rnn_scan(
@@ -404,8 +412,8 @@ class ResidualBlock(nn.Module):
         raw_x = x
 
         inputs_normalized = self.temporal_pre_norm(raw_x)
-        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
 
+        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
         residual = x + raw_x
 
         x = self.channel_pre_norm(residual)
@@ -427,7 +435,7 @@ class Griffin(nn.Module):
         )
         self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
 
-        block_types = config._block_types
+        block_types = config.block_types
 
         self.layers = [
             ResidualBlock(
@@ -461,14 +469,7 @@ class Griffin(nn.Module):
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
 
-        x = self.final_norm(x)
-        logits = self.embed_tokens.as_linear(x)
-
-        c = self.config.logits_soft_cap
-        if c:
-            logits = mx.tanh(logits / c) * c
-
-        return logits
+        return self.final_norm(x)
 
 
 class Model(nn.Module):
@@ -476,13 +477,23 @@ class Model(nn.Module):
     def __init__(self, config):
         self.args = config
         self.model = Griffin(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(self, tokens: mx.array, cache=None) -> mx.array:
         """
        Args:
            tokens: Sequence of input tokens.
        """
-        return self.model(tokens, cache=cache)
+        logits = self.model(tokens, cache=cache)
+        if "lm_head" in self:
+            logits = self.lm_head(logits)
+        else:
+            logits = self.model.embed_tokens.as_linear(logits)
+
+        c = self.args.logits_soft_cap
+        if c:
+            logits = mx.tanh(logits / c) * c
+        return logits
 
     @property
     def layers(self):
@@ -493,6 +504,8 @@ class Model(nn.Module):
         for k, v in weights.items():
             if "conv_1d.weight" in k and v.ndim == 3:
                 weights[k] = v.squeeze(1).T
+        if "lm_head.weight" not in weights:
+            self.pop("lm_head")
         return weights
 
     def make_cache(self):
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index b53971a3..bd23cca2 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -173,6 +173,7 @@ class APIHandler(BaseHTTPRequestHandler):
         endpoints = {
             "/v1/completions": self.handle_text_completions,
             "/v1/chat/completions": self.handle_chat_completions,
+            "/chat/completions": self.handle_chat_completions,
         }
 
         if self.path not in endpoints: