2024-08-17 06:28:39 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-03-03 11:39:23 +08:00
|
|
|
from dataclasses import dataclass
|
2024-10-08 11:45:51 +08:00
|
|
|
from typing import Any, Optional
|
2024-03-03 11:39:23 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
|
2024-10-08 11:45:51 +08:00
|
|
|
from .base import BaseModelArgs, create_attention_mask
|
2024-03-03 11:39:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values consumed by the Starcoder2 modules in this file.

    Fields declared without a default must be supplied by the caller.
    """

    # Identifier string; copied onto ``Model.model_type``.
    model_type: str
    # Width of the token embeddings and of every layer's input and output.
    hidden_size: int
    # Number of ``TransformerBlock`` instances stacked in ``Starcoder2Model``.
    num_hidden_layers: int
    # Inner width of the feed-forward ``MLP``.
    intermediate_size: int
    # Number of query heads; ``hidden_size`` is divided by it to get the
    # per-head size.
    num_attention_heads: int
    # Number of key/value heads.
    num_key_value_heads: int
    # ``eps`` handed to every ``nn.LayerNorm``.
    norm_epsilon: float = 1e-5
    # Number of rows in the token embedding table.
    vocab_size: int = 49_152
    # ``base`` handed to ``nn.RoPE``.
    rope_theta: float = 100_000
    # When true, logits are produced from the embedding matrix instead of a
    # separate ``lm_head`` layer.
    tie_word_embeddings: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Self-attention layer with rotary position embeddings.

    Queries are split into ``num_attention_heads`` heads, keys and values
    into ``num_key_value_heads`` heads. All four linear projections carry a
    bias term.
    """

    def __init__(self, args: ModelArgs):
        """Build the projections and the rotary embedding from ``args``."""
        super().__init__()
        self.args = args

        hidden = args.hidden_size
        q_heads = args.num_attention_heads
        kv_heads = args.num_key_value_heads
        self.n_heads = q_heads
        self.n_kv_heads = kv_heads

        per_head = hidden // q_heads
        # Softmax scaling factor: 1 / sqrt(per-head size).
        self.scale = per_head**-0.5

        q_width = q_heads * per_head
        kv_width = kv_heads * per_head
        self.q_proj = nn.Linear(hidden, q_width, bias=True)
        self.k_proj = nn.Linear(hidden, kv_width, bias=True)
        self.v_proj = nn.Linear(hidden, kv_width, bias=True)
        self.o_proj = nn.Linear(q_width, hidden, bias=True)
        self.rope = nn.RoPE(per_head, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply attention to ``x``.

        Args:
            x: Rank-3 input; its first two axes are carried through to the
                output unchanged.
            mask: Optional mask forwarded as-is to the fused attention kernel.
            cache: Optional object exposing ``offset`` and
                ``update_and_fetch(keys, values)``. When given, rotary
                embeddings start at ``cache.offset`` and the keys/values
                returned by the cache are the ones attended over.

        Returns:
            The output projection of the attended values, with the same
            first two axes as ``x``.
        """
        batch, seq_len, _ = x.shape

        def split_heads(proj: mx.array, heads: int) -> mx.array:
            # (batch, seq, heads * d) -> (batch, heads, seq, d)
            return proj.reshape(batch, seq_len, heads, -1).transpose(0, 2, 1, 3)

        queries = split_heads(self.q_proj(x), self.n_heads)
        keys = split_heads(self.k_proj(x), self.n_kv_heads)
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            # Read the offset before the cache is updated so queries and
            # keys are rotated from the same starting position.
            start = cache.offset
            queries = self.rope(queries, offset=start)
            keys = self.rope(keys, offset=start)
            keys, values = cache.update_and_fetch(keys, values)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        # (batch, heads, seq, d) -> (batch, seq, heads * d)
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(merged)
|
2024-03-03 11:39:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward block: ``dim -> hidden_dim -> dim`` with a GELU between."""

    def __init__(self, dim, hidden_dim):
        """Create the two biased linear layers.

        Args:
            dim: Input and output feature size.
            hidden_dim: Feature size between the two linear layers.
        """
        super().__init__()
        self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=True)

    def __call__(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``."""
        widened = self.c_fc(x)
        activated = nn.gelu(widened)
        return self.c_proj(activated)
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One decoder layer: attention and MLP, each normalised on input and
    added back to its own input as a residual."""

    def __init__(self, args: ModelArgs):
        """Build the attention, MLP and both layer norms from ``args``."""
        super().__init__()
        width = args.hidden_size
        eps = args.norm_epsilon

        self.hidden_size = width
        self.n_heads = args.num_attention_heads

        self.self_attn = Attention(args)
        self.mlp = MLP(width, args.intermediate_size)
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run the layer on ``x``.

        Args:
            x: Input activations.
            mask: Optional attention mask, forwarded to the attention module.
            cache: Optional per-layer cache, forwarded to the attention
                module.

        Returns:
            ``h + mlp(post_attention_layernorm(h))`` where
            ``h = x + self_attn(input_layernorm(x), mask, cache)``.
        """
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        after_attn = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
|
2024-03-03 11:39:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Starcoder2Model(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers and a final
    layer norm."""

    def __init__(self, args: ModelArgs):
        """Build the embedding table, the layer stack and the final norm.

        Raises:
            ValueError: If ``args.vocab_size`` is not positive.
        """
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        # Explicit check rather than ``assert`` so the validation is still
        # performed when Python runs with assertions disabled (``-O``).
        if self.vocab_size <= 0:
            raise ValueError(
                f"vocab_size must be positive, got {self.vocab_size}"
            )
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs`` and run them through every layer.

        Args:
            inputs: Token ids looked up in ``embed_tokens``.
            cache: Optional sequence of per-layer caches, paired with
                ``self.layers`` in order. ``None`` runs every layer without
                a cache.

        Returns:
            The final-layer activations after ``self.norm``.
        """
        h = self.embed_tokens(inputs)

        # The mask is derived from the caller's cache argument (possibly
        # None) before it is replaced by per-layer placeholders below.
        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, layer_cache in zip(self.layers, cache):
            h = layer(h, mask, layer_cache)

        return self.norm(h)
|
2024-03-03 11:39:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Starcoder2 backbone plus the projection that turns hidden states
    into vocabulary logits."""

    def __init__(self, args: ModelArgs):
        """Build the backbone and, when embeddings are not tied, a separate
        bias-free ``lm_head``."""
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Starcoder2Model(args)
        untied = not args.tie_word_embeddings
        if untied:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return output-head activations for ``inputs``.

        Args:
            inputs: Token ids passed to the backbone.
            cache: Optional per-layer caches passed to the backbone.

        Returns:
            The backbone output projected by ``lm_head`` when embeddings are
            untied, otherwise projected with the embedding table via
            ``as_linear``.
        """
        hidden = self.model(inputs, cache)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        return self.model.embed_tokens.as_linear(hidden)

    @property
    def layers(self):
        """The backbone's list of transformer layers."""
        return self.model.layers
|