# Mirror of https://github.com/ml-explore/mlx-examples.git
# Synced 2025-06-24 17:31:18 +08:00 · 194 lines · 6.1 KiB · Python
# Copyright © 2023 Apple Inc.
|
|
|
|
from dataclasses import dataclass
|
|
import math
|
|
from typing import Optional, Tuple, List
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from mlx.utils import tree_map, tree_unflatten
|
|
|
|
|
|
@dataclass
class ModelArgs:
    """Architecture hyper-parameters shared by every layer of the model."""

    # Width of the token embeddings and of the hidden states between layers.
    dim: int
    # Number of stacked TransformerBlock layers.
    n_layers: int
    # Per-head size used for the attention projections and RoPE.
    head_dim: int
    # Inner width of the gated feed-forward network.
    hidden_dim: int
    # Number of query heads.
    n_heads: int
    # Number of key/value heads; each is shared by n_heads // n_kv_heads query heads.
    n_kv_heads: int
    # Epsilon added to the mean square inside RMSNorm.
    norm_eps: float
    # Number of token ids (embedding rows and output logits).
    vocab_size: int
|
|
|
|
|
|
class LoRALinear(nn.Module):
    """Linear layer plus a trainable low-rank (LoRA) update.

    Computes ``linear(x) + scale * ((x @ lora_a) @ lora_b)``. ``lora_b`` is
    initialised to zeros, so the low-rank term is zero until it is trained.
    """

    @staticmethod
    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 2.0):
        """Wrap an existing linear layer in a LoRA adapter.

        Args:
            linear: Layer to wrap; its ``weight`` is read as
                ``(output_dims, input_dims)``.
            rank: Rank of the low-rank update.
            scale: Multiplier applied to the low-rank update (default 2.0,
                the value previously hard-coded).

        Returns:
            A new ``LoRALinear`` whose ``linear`` attribute is ``linear``.
        """
        output_dims, input_dims = linear.weight.shape
        lora_lin = LoRALinear(input_dims, output_dims, rank, scale=scale)
        # Swap the freshly initialised inner layer for the supplied one.
        lora_lin.linear = linear
        return lora_lin

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        lora_rank: int = 8,
        bias: bool = False,
        scale: float = 2.0,
    ):
        """Create the wrapped linear layer and the low-rank factors.

        Args:
            input_dims: Size of the last axis of the input.
            output_dims: Size of the last axis of the output.
            lora_rank: Inner dimension of ``lora_a``/``lora_b``.
            bias: Whether the wrapped ``nn.Linear`` has a bias.
            scale: Multiplier applied to the low-rank update.
        """
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        # Low rank lora weights: lora_a ~ U(-1/sqrt(input_dims), 1/sqrt(input_dims)),
        # lora_b = 0 so the adapter starts as a no-op update.
        init_bound = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-init_bound,
            high=init_bound,
            shape=(input_dims, lora_rank),
        )
        self.lora_b = mx.zeros(shape=(lora_rank, output_dims))

        # Multiplier for the low-rank term (a plain float, not a parameter).
        self.scale = scale

    def __call__(self, x):
        y = self.linear(x)
        z = (x @ self.lora_a) @ self.lora_b
        return y + self.scale * z
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        # Per-feature gain, initialised to one.
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        # Divide by sqrt(mean(x^2) + eps) along the last axis.
        mean_square = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_square + self.eps)

    def __call__(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normalized = self._norm(x.astype(mx.float32))
        return self.weight * normalized.astype(x.dtype)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with grouped key/value heads and RoPE.

    ``args.n_heads`` query heads share ``args.n_kv_heads`` key/value heads:
    every key/value head is repeated ``n_heads // n_kv_heads`` times so that
    each query head has a matching key/value head.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        # Number of query heads served by each key/value head.
        self.repeats = self.n_heads // self.n_kv_heads

        # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(args.head_dim, traditional=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Attend over ``x`` and any cached keys/values.

        Args:
            x: Input of shape ``(B, L, args.dim)``.
            mask: Optional additive mask, added to the attention scores of
                shape ``(B, n_heads, L, S)``, where ``S`` is the number of
                cached positions plus ``L``.
            cache: Optional ``(keys, values)`` returned by a previous call.

        Returns:
            ``(output, (keys, values))``.
            - ``output`` has shape ``(B, L, args.dim)``.
            - ``keys``/``values`` span all ``S`` positions and are already
              repeated to ``n_heads`` heads. Keys are already rotated by RoPE.
            - The pair can be passed back in as ``cache``.
        """
        B, L, _ = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Split the projections into heads: (B, heads, L, head_dim).
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            # (B, n_kv_heads, L, d) -> (B, n_heads, L, d); each key/value head
            # is duplicated `repeats` times in consecutive head slots.
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        if self.repeats > 1:
            keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            # New tokens follow the cached ones, so rotate them starting at
            # a positional offset equal to the cache length.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        # Softmax is computed in float32 and cast back to the scores' dtype.
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Gated feed-forward network computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_width, inner_width = args.dim, args.hidden_dim

        # Gate projection, output projection and "up" projection, in that order.
        self.w1 = nn.Linear(model_width, inner_width, bias=False)
        self.w2 = nn.Linear(inner_width, model_width, bias=False)
        self.w3 = nn.Linear(model_width, inner_width, bias=False)

    def __call__(self, x) -> mx.array:
        gate = nn.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer.

    Computes ``h = x + attention(attention_norm(x))`` followed by
    ``out = h + feed_forward(ffn_norm(h))``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # n_heads, dim and args are stored as attributes but not read by __call__.
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        """Apply the layer.

        Args:
            x: Hidden states; assumed ``(B, L, args.dim)``.
            mask: Optional additive attention mask, forwarded to ``Attention``.
            cache: Optional ``(keys, values)`` from a previous call.

        Returns:
            ``(out, cache)``: the layer output and the updated
            ``(keys, values)`` produced by the attention sub-layer.
        """
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache
|
|
|
|
|
|
class Model(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings pass through ``args.n_layers`` ``TransformerBlock``s, a
    final ``RMSNorm`` and a linear projection to ``args.vocab_size`` logits.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        # NOTE: `assert` is skipped under `python -O`; kept so the exception
        # type seen by existing callers does not change.
        assert self.vocab_size > 0
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run the model on a batch of token ids.

        Args:
            inputs: Token ids; assumed ``(batch, seq_len)`` (the sequence
                length is read from axis 1 of the embedded tensor).
            cache: Optional list with one entry per layer, each ``None`` or
                the ``(keys, values)`` pair returned by a previous call. A
                supplied list is updated in place.

        Returns:
            ``(logits, cache)``.
            - ``logits`` is the output projection of the normalised final
              hidden states.
            - ``cache`` holds each layer's updated ``(keys, values)``.
        """
        h = self.tok_embeddings(inputs)
        seq_len = h.shape[1]

        if cache is None:
            cache = [None] * len(self.layers)

        mask = None
        if seq_len > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len)
            mask = mask.astype(h.dtype)
            # Cached positions are visible to every new query, so prepend zero
            # (unmasked) columns for them. Without this, the (seq_len, seq_len)
            # mask cannot broadcast against scores of width cached + seq_len,
            # and multi-token calls with a non-empty cache fail.
            offset = cache[0][0].shape[2] if cache and cache[0] is not None else 0
            if offset > 0:
                mask = mx.concatenate(
                    [mx.zeros((seq_len, offset), dtype=h.dtype), mask], axis=1
                )

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.output(self.norm(h)), cache
|