mlx-examples/llms/mlx_lm/models/gpt_bigcode.py
Awni Hannun fca087be49
More cache improvements (#1015)
* fix rotating kv cache for chat use case

* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-10-07 20:45:51 -07:00

187 lines
5.2 KiB
Python

# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters for a GPT-BigCode model, as read from its config."""

    model_type: str
    n_embd: int
    n_layer: int
    # Feed-forward width; may be null in the config, resolved in __post_init__.
    n_inner: Optional[int]
    n_head: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    # Number of key/value heads; resolved from `multi_query` when unset.
    num_key_value_heads: Optional[int] = None
    multi_query: bool = True
    attention_bias: bool = True
    mlp_bias: bool = True
    tie_word_embeddings: bool = True

    def __post_init__(self):
        """Fill in derived defaults for fields the config left unset."""
        if self.num_key_value_heads is None:
            # Multi-query attention shares a single K/V head across all queries.
            self.num_key_value_heads = 1 if self.multi_query else self.n_head
        if self.n_inner is None:
            # GPT-BigCode's documented default: 4x the embedding width.
            self.n_inner = 4 * self.n_embd
class Attention(nn.Module):
    """Causal self-attention with a fused QKV projection (``c_attn``).

    Keys/values may use fewer heads than queries (``n_kv_heads``); with
    ``multi_query`` a single K/V head is shared by all query heads.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.dim = dim = args.n_embd
        self.n_heads = n_heads = args.n_head
        # Use the K/V head count already resolved on the args object; fall back
        # to the multi_query flag for args objects that do not carry it.
        n_kv_heads = getattr(args, "num_key_value_heads", None)
        if n_kv_heads is None:
            n_kv_heads = 1 if args.multi_query else args.n_head
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim = dim // n_heads
        self.kv_dim = n_kv_heads * head_dim

        self.scale = head_dim**-0.5

        attention_bias = getattr(args, "attention_bias", False)

        # One projection produces [queries (dim) | keys (kv_dim) | values (kv_dim)].
        self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
        self.c_proj = nn.Linear(dim, dim, bias=attention_bias)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Attend over ``x`` (and any cached keys/values) and project back.

        ``x`` is unpacked as (B, L, dim). If ``cache`` is given, the new keys
        and values are appended via ``cache.update_and_fetch``.
        """
        B, L, _ = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(
            qkv, [self.dim, self.dim + self.kv_dim], axis=-1
        )

        # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        # Back to (B, L, dim) before the output projection.
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)
class MLP(nn.Module):
    """Feed-forward sub-layer: ``c_proj(gelu(c_fc(x)))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.n_embd
        hidden_dim = args.n_inner
        if hidden_dim is None:
            # An unset n_inner means the default 4x expansion.
            hidden_dim = 4 * dim
        mlp_bias = getattr(args, "mlp_bias", False)

        self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        # NOTE(review): this is the exact GELU; confirm it matches the
        # checkpoint's configured activation (some configs use a tanh approx).
        return self.c_proj(nn.gelu(self.c_fc(x)))
class TransformerBlock(nn.Module):
    """Pre-LayerNorm decoder layer with residual attention and MLP branches."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply ``x + attn(ln_1(x))`` and then add ``mlp(ln_2(...))``."""
        after_attn = x + self.attn(self.ln_1(x), mask, cache)
        mlp_branch = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_branch
class GPTBigCodeModel(nn.Module):
    """GPT-BigCode decoder stack.

    Token embeddings plus learned absolute position embeddings, followed by
    ``n_layer`` transformer blocks and a final LayerNorm.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0
        self.wte = nn.Embedding(args.vocab_size, args.n_embd)
        self.wpe = nn.Embedding(args.n_positions, args.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
        self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run the stack on token ids ``inputs`` (unpacked as (B, L)).

        ``cache``, when given, is a per-layer list of cache objects. Its first
        entry's ``offset`` is used as the absolute position of the first new
        token.
        """
        L = inputs.shape[1]

        hidden_states = self.wte(inputs)

        # Position embeddings must be added on every call, including cached
        # single-token decoding, and must continue from the cache offset rather
        # than restarting at 0.
        offset = 0
        if cache is not None and cache[0] is not None:
            offset = cache[0].offset
        position_ids = mx.arange(offset, offset + L)
        hidden_states = hidden_states + self.wpe(position_ids)

        mask = None
        if L > 1:
            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)
class Model(nn.Module):
    """Causal language model: GPT-BigCode transformer plus output projection.

    When ``tie_word_embeddings`` is set, logits reuse the token embedding
    matrix; otherwise a separate bias-free ``lm_head`` is created.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = GPTBigCodeModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        hidden = self.transformer(inputs, cache)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        # Tied weights: project with the transposed token embedding table.
        return self.transformer.wte.as_linear(hidden)

    @property
    def layers(self):
        """The list of transformer blocks (one per layer)."""
        return self.transformer.h