2024-08-17 06:28:39 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-10-08 11:45:51 +08:00
|
|
|
import sys
|
2024-02-06 13:13:49 +08:00
|
|
|
from dataclasses import dataclass
|
2024-10-08 11:45:51 +08:00
|
|
|
from typing import Any, Optional, Tuple
|
2024-02-06 13:13:49 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
|
2024-07-26 07:45:22 +08:00
|
|
|
from .base import BaseModelArgs, create_attention_mask
|
2024-02-06 13:13:49 +08:00
|
|
|
|
2024-02-13 02:51:02 +08:00
|
|
|
try:
    # The import acts as an availability check for the `ai2-olmo` package;
    # `hf_olmo` is not otherwise referenced in this module's visible code.
    import hf_olmo  # noqa: F401
except ImportError:
    # NOTE(review): terminating the interpreter at import time is abrupt for a
    # library module; kept as-is to preserve existing behavior.
    print("To run olmo install ai2-olmo: pip install ai2-olmo")
    raise SystemExit(1)
|
2024-02-13 02:51:02 +08:00
|
|
|
|
2024-02-06 13:13:49 +08:00
|
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters for an OLMo model.

    ``mlp_hidden_size`` may be given as ``None``; ``__post_init__`` then
    derives it as ``mlp_ratio * d_model``. The field is still required
    (no default) so the positional field order is unchanged.
    """

    model_type: str
    d_model: int
    n_layers: int
    # Output width of the fused feed-forward input projection. The block splits
    # that output into two halves for the SiLU-gated activation, so the
    # down-projection takes mlp_hidden_size // 2 features. None -> derived.
    mlp_hidden_size: Optional[int]
    n_heads: int
    vocab_size: int
    # Number of rows in the token embedding (and width of the untied LM head).
    embedding_size: int
    rope_theta: float = 10000
    rope_traditional: bool = False
    # Multiplier used only when mlp_hidden_size is None.
    mlp_ratio: int = 4
    # When True, the output projection reuses the input embedding matrix.
    weight_tying: bool = False

    def __post_init__(self):
        """Fill in ``mlp_hidden_size`` from ``mlp_ratio`` when it is ``None``."""
        self.mlp_hidden_size = (
            self.mlp_hidden_size
            if self.mlp_hidden_size is not None
            else self.mlp_ratio * self.d_model
        )
|
2024-02-06 13:13:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One OLMo decoder layer: pre-norm self-attention, then a gated MLP.

    Both normalisations are non-affine LayerNorms and every linear layer is
    created without bias. Each sub-layer is added back onto its input.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        width = args.d_model
        hidden = args.mlp_hidden_size

        self.n_heads = args.n_heads

        # Feed-forward: a single fused projection yields both halves of the
        # gated activation; the down-projection consumes half of that width.
        self.ff_proj = nn.Linear(width, hidden, bias=False)
        self.ff_out = nn.Linear(hidden // 2, width, bias=False)

        self.att_norm = nn.LayerNorm(width, affine=False)
        self.ff_norm = nn.LayerNorm(width, affine=False)

        per_head = width // self.n_heads
        self.scale = per_head**-0.5

        # Fused query/key/value projection and the attention output projection.
        self.att_proj = nn.Linear(width, 3 * width, bias=False)
        self.attn_out = nn.Linear(width, width, bias=False)

        self.rope = nn.RoPE(
            per_head,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

        self.args = args

    def _to_heads(self, t: mx.array, batch: int, seq_len: int) -> mx.array:
        """Reshape (batch, seq_len, width) into (batch, heads, seq_len, head_dim)."""
        return t.reshape(batch, seq_len, self.n_heads, -1).transpose(0, 2, 1, 3)

    def attend(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Multi-head self-attention with rotary position embeddings.

        Args:
            x: 3-D input; unpacked as (batch, seq_len, width).
            mask: optional additive mask applied to the attention scores.
            cache: optional cache object. When given, RoPE is offset by
                ``cache.offset`` and the fresh keys/values are passed through
                ``cache.update_and_fetch`` before scoring.

        Returns:
            The attention output after the ``attn_out`` projection.
        """
        batch, seq_len, _ = x.shape

        q, k, v = (
            self._to_heads(part, batch, seq_len)
            for part in mx.split(self.att_proj(x), 3, axis=-1)
        )

        if cache is None:
            q = self.rope(q)
            k = self.rope(k)
        else:
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)

        logits = (q * self.scale) @ k.transpose(0, 1, 3, 2)
        if mask is not None:
            logits = logits + mask

        # Softmax is evaluated in float32, then cast back to the score dtype.
        probs = mx.softmax(logits.astype(mx.float32), axis=-1).astype(logits.dtype)

        merged = (probs @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.attn_out(merged)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply attention and the gated MLP, each with a residual connection."""
        h = x + self.attend(self.att_norm(x), mask, cache)

        # First half is the value path, second half goes through SiLU as gate.
        value, gate = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
        return h + self.ff_out(nn.silu(gate) * value)
|
2024-02-06 13:13:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """OLMo decoder stack: embedding, transformer blocks, final norm, LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_layers = args.n_layers
        self.weight_tying = args.weight_tying

        self.wte = nn.Embedding(args.embedding_size, args.d_model)
        self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        if not self.weight_tying:
            # Dedicated output projection; with weight tying the embedding
            # matrix is reused via `wte.as_linear` instead.
            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
        self.norm = nn.LayerNorm(args.d_model, affine=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Compute output logits for ``inputs``.

        Args:
            inputs: token ids fed to the embedding (presumably (batch, seq);
                TODO confirm against callers).
            cache: optional sequence with one cache entry per block, or None.

        Returns:
            The logits array. Both the tied and untied branches return the
            same type (logits only, no cache).
        """
        h = self.wte(inputs)

        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.blocks)

        for block, c in zip(self.blocks, cache):
            h = block(h, mask, c)

        h = self.norm(h)

        if self.weight_tying:
            # Return bare logits, matching the untied branch below; returning
            # (logits, cache) here gave callers a tuple only for tied models.
            return self.wte.as_linear(h)

        return self.ff_out(h)
|
2024-02-06 13:13:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
class OlmoModel(nn.Module):
    """Wrapper that holds the decoder stack under the ``transformer`` attribute."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.transformer = Transformer(args)

    def __call__(self, inputs: mx.array, cache=None):
        """Forward ``inputs`` and the optional ``cache`` to the decoder stack."""
        return self.transformer(inputs, cache=cache)
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Top-level OLMo model: records the model type and wraps ``OlmoModel``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = OlmoModel(args)
        self.args = args

    def __call__(self, inputs: mx.array, cache=None):
        """Return the wrapped model's output for ``inputs``."""
        return self.model(inputs, cache=cache)

    @property
    def layers(self):
        """The ``TransformerBlock`` list, in the order the stack applies them."""
        stack = self.model.transformer
        return stack.blocks
|