mirror of https://github.com/ml-explore/mlx-examples.git
parent 20e221f7f7
commit 68e88d42fb
@@ -14,7 +14,6 @@ class ModelArgs(BaseModelArgs):
     hidden_size: int
     attention_bias: bool
     conv1d_width: int
-    embeddings_scale_by_sqrt_dim: bool
     hidden_size: int
     intermediate_size: int
     logits_soft_cap: float
@@ -25,7 +24,14 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float
     attention_window_size: int
     vocab_size: int
-    _block_types: List[str]
+    embeddings_scale_by_sqrt_dim: bool = True
+    block_types: Optional[List[str]] = None
+    _block_types: Optional[List[str]] = None
+
+    def __post_init__(self):
+        # For some reason these have different names in 2B and 9B
+        if self.block_types is None:
+            self.block_types = self._block_types
 
 
 def create_window_causal_mask(N: int, window_size: int):
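The new `__post_init__` only aliases the two spellings of the block-type list so that downstream code can always read `block_types`, whichever key the checkpoint's config provides. A minimal, self-contained sketch of that behavior (the `_ArgsSketch` dataclass below is a hypothetical stand-in, not part of the repo):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _ArgsSketch:
    # Only the fields relevant to the aliasing; one checkpoint family ships
    # "block_types" in its config, the other ships "_block_types".
    block_types: Optional[List[str]] = None
    _block_types: Optional[List[str]] = None

    def __post_init__(self):
        if self.block_types is None:
            self.block_types = self._block_types


assert _ArgsSketch(_block_types=["recurrent", "attention"]).block_types == ["recurrent", "attention"]
assert _ArgsSketch(block_types=["recurrent", "attention"]).block_types == ["recurrent", "attention"]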
@@ -202,6 +208,8 @@ class RGLRU(nn.Module):
 
         # Apply gamma normalization to the input.
         multiplier = mx.sqrt(1 - a_square)
+        if cache is None:
+            multiplier[:, 0, :] = 1.0
         normalized_x = gated_x * multiplier.astype(x.dtype)
 
         y, last_h = rnn_scan(
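The two added lines pin the gamma-normalization multiplier to 1.0 at the first position when there is no cache yet (i.e. at the start of a prompt), so the very first input is not scaled down by sqrt(1 - a^2). A small sketch of the effect with made-up numbers:

import mlx.core as mx

# Hypothetical values: batch=1, seq=3, dim=2. Without the fix, a_square close
# to 1 would shrink the first position toward zero.
a_square = mx.array([[[0.99, 0.99], [0.64, 0.64], [0.36, 0.36]]])
multiplier = mx.sqrt(1 - a_square)  # [[0.1, 0.1], [0.6, 0.6], [0.8, 0.8]]
cache = None
if cache is None:
    multiplier[:, 0, :] = 1.0  # first position passes through unscaled
print(multiplier)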
@@ -404,8 +412,8 @@ class ResidualBlock(nn.Module):
         raw_x = x
 
         inputs_normalized = self.temporal_pre_norm(raw_x)
-        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
 
+        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
         residual = x + raw_x
 
         x = self.channel_pre_norm(residual)
@@ -427,7 +435,7 @@ class Griffin(nn.Module):
         )
 
         self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
-        block_types = config._block_types
+        block_types = config.block_types
 
         self.layers = [
             ResidualBlock(
@@ -461,14 +469,7 @@ class Griffin(nn.Module):
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
 
-        x = self.final_norm(x)
-        logits = self.embed_tokens.as_linear(x)
-
-        c = self.config.logits_soft_cap
-        if c:
-            logits = mx.tanh(logits / c) * c
-
-        return logits
+        return self.final_norm(x)
 
 
 class Model(nn.Module):
@@ -476,13 +477,23 @@ class Model(nn.Module):
     def __init__(self, config):
         self.args = config
         self.model = Griffin(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(self, tokens: mx.array, cache=None) -> mx.array:
         """
         Args:
             tokens: Sequence of input tokens.
         """
-        return self.model(tokens, cache=cache)
+        logits = self.model(tokens, cache=cache)
+        if "lm_head" in self:
+            logits = self.lm_head(logits)
+        else:
+            logits = self.model.embed_tokens.as_linear(logits)
+
+        c = self.args.logits_soft_cap
+        if c:
+            logits = mx.tanh(logits / c) * c
+        return logits
 
     @property
     def layers(self):
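With the head moved out of Griffin, `Model.__call__` now owns both the output projection (a dedicated `lm_head` if present, otherwise the tied embedding matrix) and the logit soft cap. The cap squashes logits smoothly into (-c, c); a quick numeric illustration, with c = 30.0 chosen arbitrarily rather than taken from any config:

import mlx.core as mx

c = 30.0
logits = mx.array([-100.0, -10.0, 0.0, 10.0, 100.0])
capped = mx.tanh(logits / c) * c
print(capped)  # roughly [-29.9, -9.6, 0.0, 9.6, 29.9]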
@@ -493,6 +504,8 @@ class Model(nn.Module):
         for k, v in weights.items():
             if "conv_1d.weight" in k and v.ndim == 3:
                 weights[k] = v.squeeze(1).T
+        if "lm_head.weight" not in weights:
+            self.pop("lm_head")
         return weights
 
     def make_cache(self):
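Together with the `__call__` change above, `sanitize` makes the output head optional: when a checkpoint ships no `lm_head.weight`, the unused `lm_head` module is popped and generation falls back to the tied embedding matrix via `embed_tokens.as_linear`. A toy sketch of the same pattern (the `TinyLM` class is hypothetical, not part of the repo):

import mlx.core as mx
import mlx.nn as nn


class TinyLM(nn.Module):
    def __init__(self, dims=4, vocab=10):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dims)
        self.lm_head = nn.Linear(dims, vocab, bias=False)

    def __call__(self, tokens):
        h = self.embed_tokens(tokens)
        # nn.Module is dict-like, so membership reflects whether the child exists.
        if "lm_head" in self:
            return self.lm_head(h)
        return self.embed_tokens.as_linear(h)


m = TinyLM()
m.pop("lm_head")  # what sanitize() does when the weights contain no lm_head.weight
print(m(mx.array([[1, 2, 3]])).shape)  # (1, 3, 10) via the tied-embedding path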
@@ -173,6 +173,7 @@ class APIHandler(BaseHTTPRequestHandler):
         endpoints = {
             "/v1/completions": self.handle_text_completions,
             "/v1/chat/completions": self.handle_chat_completions,
+            "/chat/completions": self.handle_chat_completions,
         }
 
         if self.path not in endpoints:
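The server's request handler now also accepts the un-versioned `/chat/completions` path, which some OpenAI-compatible clients use instead of `/v1/chat/completions`. A minimal request against a locally running `mlx_lm.server`, assuming the default host and port of 127.0.0.1:8080:

import json
from urllib.request import Request, urlopen

payload = {"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 16}
req = Request(
    "http://127.0.0.1:8080/chat/completions",  # now routed like /v1/chat/completions
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
print(json.load(urlopen(req)))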