mirror of https://github.com/ml-explore/mlx-examples.git
parent 20e221f7f7
commit 68e88d42fb
@@ -14,7 +14,6 @@ class ModelArgs(BaseModelArgs):
     hidden_size: int
     attention_bias: bool
     conv1d_width: int
-    embeddings_scale_by_sqrt_dim: bool
     hidden_size: int
     intermediate_size: int
     logits_soft_cap: float
@@ -25,7 +24,14 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float
     attention_window_size: int
     vocab_size: int
-    _block_types: List[str]
+    embeddings_scale_by_sqrt_dim: bool = True
+    block_types: Optional[List[str]] = None
+    _block_types: Optional[List[str]] = None
+
+    def __post_init__(self):
+        # For some reason these have different names in 2B and 9B
+        if self.block_types is None:
+            self.block_types = self._block_types


 def create_window_causal_mask(N: int, window_size: int):
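
The new __post_init__ only papers over a naming difference: the 2B and 9B checkpoints ship the block-type list under different config keys (block_types vs. _block_types), and the hook copies the private one across when the public one is missing. A minimal standalone sketch of that aliasing, with invented field values for illustration:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ArgsSketch:
    # Reduced stand-in for ModelArgs, keeping only the aliased fields.
    block_types: Optional[List[str]] = None
    _block_types: Optional[List[str]] = None

    def __post_init__(self):
        # Same fallback as the patch: prefer block_types, else use _block_types.
        if self.block_types is None:
            self.block_types = self._block_types

args = ArgsSketch(_block_types=["recurrent", "recurrent", "attention"])
print(args.block_types)  # ['recurrent', 'recurrent', 'attention']
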
@@ -202,6 +208,8 @@ class RGLRU(nn.Module):

         # Apply gamma normalization to the input.
         multiplier = mx.sqrt(1 - a_square)
+        if cache is None:
+            multiplier[:, 0, :] = 1.0
         normalized_x = gated_x * multiplier.astype(x.dtype)

         y, last_h = rnn_scan(
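
The two added lines only change the very first step of a fresh sequence: when there is no cache, the gamma factor sqrt(1 - a^2) at position 0 is forced to 1 so the first token enters the recurrence at full magnitude instead of being damped. A toy illustration of the effect (batch 1, three steps, two channels; the gate values are made up):

import mlx.core as mx

a_square = mx.array([[[0.64, 0.96], [0.64, 0.96], [0.64, 0.96]]])
gated_x = mx.ones_like(a_square)

multiplier = mx.sqrt(1 - a_square)   # gamma = sqrt(1 - a^2) -> 0.6 and 0.2 per channel
multiplier[:, 0, :] = 1.0            # first step of a fresh sequence is left unscaled
normalized_x = gated_x * multiplier.astype(gated_x.dtype)
print(normalized_x)                  # row 0 stays at 1.0, later rows are scaled down
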
@@ -404,8 +412,8 @@ class ResidualBlock(nn.Module):
         raw_x = x

         inputs_normalized = self.temporal_pre_norm(raw_x)
-        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)

+        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
         residual = x + raw_x

         x = self.channel_pre_norm(residual)
@@ -427,7 +435,7 @@ class Griffin(nn.Module):
         )

         self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
-        block_types = config._block_types
+        block_types = config.block_types

         self.layers = [
             ResidualBlock(
@@ -461,14 +469,7 @@ class Griffin(nn.Module):
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])

-        x = self.final_norm(x)
-        logits = self.embed_tokens.as_linear(x)
-
-        c = self.config.logits_soft_cap
-        if c:
-            logits = mx.tanh(logits / c) * c
-
-        return logits
+        return self.final_norm(x)


 class Model(nn.Module):
@@ -476,13 +477,23 @@ class Model(nn.Module):
     def __init__(self, config):
         self.args = config
         self.model = Griffin(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

     def __call__(self, tokens: mx.array, cache=None) -> mx.array:
         """
         Args:
             tokens: Sequence of input tokens.
         """
-        return self.model(tokens, cache=cache)
+        logits = self.model(tokens, cache=cache)
+        if "lm_head" in self:
+            logits = self.lm_head(logits)
+        else:
+            logits = self.model.embed_tokens.as_linear(logits)
+
+        c = self.args.logits_soft_cap
+        if c:
+            logits = mx.tanh(logits / c) * c
+        return logits

     @property
     def layers(self):
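
Moving the cap out of Griffin and into Model.__call__ keeps the numerics identical: the decoder now returns plain hidden states, the head (either the new untied lm_head or the tied embed_tokens.as_linear fallback) produces logits, and the soft cap squashes them through c * tanh(x / c), which bounds every logit to (-c, c) while staying close to the identity near zero. A quick numeric check with an arbitrary cap value:

import mlx.core as mx

c = 30.0                                        # illustrative cap, not the model's value
logits = mx.array([-1000.0, -10.0, 0.0, 10.0, 1000.0])
capped = mx.tanh(logits / c) * c
print(capped)                                   # roughly [-30, -9.6, 0, 9.6, 30]
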
@@ -493,6 +504,8 @@ class Model(nn.Module):
         for k, v in weights.items():
             if "conv_1d.weight" in k and v.ndim == 3:
                 weights[k] = v.squeeze(1).T
+        if "lm_head.weight" not in weights:
+            self.pop("lm_head")
         return weights

     def make_cache(self):
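
sanitize now does two load-time fix-ups: 3-D conv_1d.weight tensors are squeezed and transposed into the 2-D layout the temporal conv in this file consumes, and checkpoints that carry no lm_head.weight (the tied-embedding variants) get the freshly built lm_head popped so __call__ falls through to embed_tokens.as_linear. A small sketch of just the weight transform, with dummy values; the source layout is an assumption, only the squeeze/transpose mirrors the patch:

import mlx.core as mx

channels, width = 4, 3
v = mx.arange(channels * width, dtype=mx.float32).reshape(channels, 1, width)

w = v.squeeze(1).T               # (4, 1, 3) -> (4, 3) -> (3, 4)
print(v.shape, "->", w.shape)
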
@@ -173,6 +173,7 @@ class APIHandler(BaseHTTPRequestHandler):
         endpoints = {
             "/v1/completions": self.handle_text_completions,
             "/v1/chat/completions": self.handle_chat_completions,
+            "/chat/completions": self.handle_chat_completions,
         }

         if self.path not in endpoints:
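
The extra dictionary entry simply aliases the chat route so clients that omit the /v1 prefix still land on handle_chat_completions. A hedged, standard-library-only example call (assumes an mlx_lm.server running locally on its default port 8080 with a chat model loaded, and the OpenAI-style response shape the server mimics):

import json
from urllib import request

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 32,
}
req = request.Request(
    "http://localhost:8080/chat/completions",   # the newly aliased, un-prefixed route
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    reply = json.loads(resp.read())
    print(reply["choices"][0]["message"]["content"])
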