Lazy import + refactor Lora layer addition (#426)

* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
Awni Hannun
2024-02-12 10:51:02 -08:00
committed by GitHub
parent 4576946151
commit d4666615bb
15 changed files with 127 additions and 393 deletions
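The "lazy model import" bullet refers to resolving the architecture module from the config's model_type string at load time, instead of importing every models/*.py up front. A minimal sketch of the pattern, assuming the per-architecture modules live under mlx_lm.models and each exposes Model and ModelArgs; the helper name and error handling here are illustrative, not necessarily the repo's exact code:

# Hypothetical sketch of the lazy-import pattern (names are illustrative).
import importlib

def _get_model_classes(config: dict):
    model_type = config["model_type"]
    try:
        # Only the requested architecture module is imported.
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        raise ValueError(f"Model type {model_type} not supported.")
    return arch.Model, arch.ModelArgs

This is also why the olmo hunk below can keep its guarded hf_olmo import at module scope: olmo.py is only imported when an OLMo config is actually being loaded.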

View File

@@ -9,6 +9,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
@@ -18,7 +19,6 @@ class ModelArgs(BaseModelArgs):
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
@@ -190,6 +190,7 @@ class LlamaModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
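A note on the recurring hunk above and in the files below: model_type moves from an optional field with a None default to a required ModelArgs field, and each Model now stores it as self.model_type. That gives the LoRA tuning code a single attribute to dispatch on when deciding which submodules to wrap, rather than importing every architecture just to isinstance-check the model. A hypothetical sketch of such a dispatch; the helper name and branch details are assumptions based only on the structures visible in this diff:

# Hypothetical helper: pick the transformer blocks to wrap with LoRA based on
# the model_type attribute that each Model now exposes.
def get_transformer_layers(model):
    if model.model_type == "plamo":
        # plamo nests its decoder one level deeper (see the plamo hunk below)
        return model.model.layers.layers
    if model.model_type == "phixtral":
        # phixtral keeps its blocks under `transformer.h` (see the new file below)
        return model.transformer.h
    # llama-style models (llama, mixtral, qwen2, ...) keep blocks in `model.layers`
    return model.model.layers

Only the attribute is consulted, so no architecture module has to be imported just to test the model's type.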

View File

@@ -21,9 +21,9 @@ class ModelArgs(BaseModelArgs):
num_local_experts: int = 8
rms_norm_eps: float = 1e-5
vocab_size: int
model_type: str
rope_theta: float = 1e6
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
@@ -252,6 +252,7 @@ class MixtralModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = MixtralModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

View File

@@ -7,18 +7,25 @@ import mlx.nn as nn
from .base import BaseModelArgs
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
d_model: int
n_layers: int
mlp_hidden_size: int
n_heads: int
vocab_size: int
embedding_size: int
model_type: str
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
mlp_ratio: int = 4
weight_tying: bool = False
@@ -162,11 +169,7 @@ class OlmoModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
self.model_type = args.model_type
self.model = OlmoModel(args)
def __call__(

View File

@@ -10,6 +10,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
max_position_embeddings: int = 2048
vocab_size: int = 51200
hidden_size: int = 2560
@@ -163,6 +164,7 @@ class PhiModel(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = PhiModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

View File

@@ -0,0 +1,218 @@
import glob
import inspect
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer


@dataclass
class ModelArgs:
    model_type: str
    max_sequence_length: int = 2048
    num_vocab: int = 51200
    model_dim: int = 2560
    num_heads: int = 32
    num_layers: int = 32
    rotary_dim: int = 32
    num_experts_per_tok: int = 2
    num_local_experts: int = 4

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)


class RoPEAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(rotary_dim, traditional=False)
        self.Wqkv = nn.Linear(dims, 3 * dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x, mask=None, cache=None):
        qkv = self.Wqkv(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        queries = queries.astype(mx.float32)
        keys = keys.astype(mx.float32)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU(approx="precise")

    def __call__(self, x) -> mx.array:
        return self.fc2(self.act(self.fc1(x)))


class MOE(nn.Module):
    def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.num_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok
        self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)]
        self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        ne = self.num_experts_per_tok
        orig_shape = x.shape
        x = x.reshape(-1, x.shape[-1])

        gates = self.gate(x)
        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne]
        scores = mx.softmax(
            mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
            axis=-1,
        ).astype(gates.dtype)

        if self.training:
            ys = []
            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
            for e, expert in enumerate(self.mlp):
                idx1, idx2 = map(mx.array, np.where(inds == e))
                if idx1.size == 0:
                    continue
                y[idx1, idx2] = expert(x[idx1])
            y = (y * scores[..., None]).sum(axis=1)
        else:
            y = []
            for xt, st, it in zip(x, scores, inds.tolist()):
                yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
                yt = (yt * st).sum(axis=-1)
                y.append(yt[None, :])
            y = mx.concatenate(y)

        return y.reshape(orig_shape)


class ParallelBlock(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        dims = config.model_dim
        mlp_dims = dims * 4
        self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
        self.ln = LayerNorm(dims)
        self.moe = MOE(config, dims, mlp_dims)

    def __call__(self, x, mask, cache):
        h = self.ln(x)
        attn_h, cache = self.mixer(h, mask, cache)
        ff_h = self.moe(h)
        return attn_h + ff_h + x, cache


class TransformerDecoder(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embd = Embd(config)
        self.h = [ParallelBlock(config) for i in range(config.num_layers)]

    def __call__(self, x, mask, cache):
        x = self.embd(x)
        if cache is None:
            cache = [None] * len(self.h)
        for e, layer in enumerate(self.h):
            x, cache[e] = layer(x, mask, cache[e])
        return x, cache


class Embd(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(config.num_vocab, config.model_dim)

    def __call__(self, x):
        return self.wte(x)


class OutputHead(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.ln = LayerNorm(config.model_dim)
        self.linear = nn.Linear(config.model_dim, config.num_vocab)

    def __call__(self, inputs):
        return self.linear(self.ln(inputs))


class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model_type = config.model_type
        self.transformer = TransformerDecoder(config)
        self.lm_head = OutputHead(config)

    def __call__(
        self,
        x: mx.array,
        mask: mx.array = None,
        cache: mx.array = None,
    ) -> Tuple[mx.array, mx.array]:
        mask = None
        if x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)

        y, cache = self.transformer(x, mask, cache)
        return self.lm_head(y), cache
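The MOE block in the new file above selects each token's top num_experts_per_tok experts with mx.argpartition and re-normalizes their gate scores with a softmax over just those experts. A standalone sketch of that gating step on a made-up 3-token, 4-expert gate matrix (values are toy numbers):

import mlx.core as mx

gates = mx.array([[0.1, 2.0, 0.3, 1.5],
                  [1.2, 0.1, 0.9, 0.2],
                  [0.0, 0.5, 3.0, 2.9]])
ne = 2  # experts per token
# Indices of each token's two largest gates (order within the top-k is unspecified):
# row 0 picks experts {1, 3}, row 1 picks {0, 2}, row 2 picks {2, 3}.
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
# Mixture weights over the selected experts only.
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
print(inds)
print(scores)

In the training branch the argpartition result is additionally wrapped in mx.stop_gradient, so gradients flow through the selected gate scores but not through the routing indices themselves.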

View File

@@ -1,126 +1,25 @@
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from transformers import PretrainedConfig
from .base import BaseModelArgs
class DecoderInput(NamedTuple):
hidden_states: mx.array
position_ids: mx.array
attention_mask: Optional[mx.array] = None
past_key_values: Optional[List[mx.array]] = None
output_hidden_states: Optional[bool] = False
output_attentions: Optional[bool] = False
use_cache: Optional[bool] = False
gradient_checkpointing: bool = False
class DecoderOutput(NamedTuple):
hidden_states: mx.array
all_hidden_states: Optional[Tuple[mx.array, ...]]
all_self_attns: Optional[Tuple[mx.array, ...]]
next_decoder_cache: Optional[Tuple[mx.array, ...]]
class ModelArgs(PretrainedConfig): # type: ignore
model_type: str = "plamo"
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 4096,
intermediate_size: int = 13312,
num_hidden_layers: int = 32,
num_attention_heads: int = 32,
max_position_embeddings: int = 2048,
initializer_range: float = 0.02,
rms_norm_eps: float = 1e-6,
use_cache: bool = True,
tokenizer_class: str = "PlamoTokenizer",
pad_token_id: Optional[int] = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
n_shared_head: int = 8,
tie_word_embeddings: bool = False,
**kwargs: Any,
) -> None:
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.n_shared_head = n_shared_head
super().__init__(
tokenizer_class=tokenizer_class,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class RotaryEmbedding:
def __init__(
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
) -> None:
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / mx.power(
self.base, mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim
)
self.cos_cached = mx.zeros((1, 1, max_position_embeddings, dim))
self.sin_cached = mx.zeros((1, 1, max_position_embeddings, dim))
self._set_cos_sin_cache(max_position_embeddings)
def _set_cos_sin_cache(self, seq_len: int) -> None:
self.max_seq_len_cached = seq_len
t = mx.arange(self.max_seq_len_cached) # type: ignore
freqs = mx.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = mx.concatenate((freqs, freqs), axis=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]:
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
return (
self.cos_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
self.sin_cached[:, :, :seq_len, ...].astype(x.dtype), # type: ignore
)
def _rotate_half(x: mx.array) -> mx.array:
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate((-x2, x1), axis=-1)
def _rotary_pos_emb(
x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
) -> mx.array:
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = mx.squeeze(cos, (0, 1)) # [seq_len, dim]
sin = mx.squeeze(sin, (0, 1)) # [seq_len, dim]
cos = cos[position_ids][:, None] # [bs, 1, seq_len, dim]
sin = sin[position_ids][:, None] # [bs, 1, seq_len, dim]
x_embed = (x * cos) + (_rotate_half(x) * sin)
return x_embed
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
n_shared_head: int = (8,)
rope_theta: float = 10000
rope_traditional: bool = False
class RMSNorm(nn.Module):
@@ -143,7 +42,6 @@ class Attention(nn.Module):
self.config = config
self.hidden_size = config.hidden_size
head_dim = self.hidden_size // config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.q_num_heads = config.num_attention_heads
self.qk_dim = self.v_dim = head_dim
@@ -165,15 +63,17 @@ class Attention(nn.Module):
self.o_proj = nn.Linear(
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
)
self.rotary_emb = RotaryEmbedding(
self.qk_dim, max_position_embeddings=self.max_position_embeddings
self.rotary_emb = nn.RoPE(
head_dim,
traditional=config.rope_traditional,
base=config.rope_theta,
scale=1.0,
)
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
bsz, q_len, _ = hidden_states.shape
@@ -204,13 +104,11 @@ class Attention(nn.Module):
key_states = _expand_kv(key_states)
value_states = _expand_kv(value_states)
kv_seq_len = key_states.shape[-2]
kv_seq_len = 0
if cache is not None:
kv_seq_len += cache[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
assert position_ids is not None
query_states = _rotary_pos_emb(query_states, cos, sin, position_ids)
key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
query_states = self.rotary_emb(query_states, offset=kv_seq_len)
key_states = self.rotary_emb(key_states, offset=kv_seq_len)
if cache is not None:
# reuse k, v, self_attention
@@ -235,10 +133,9 @@ class MLP(nn.Module):
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
class PlamoDecoderLayer(nn.Module):
@@ -254,7 +151,6 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[Any, ...]:
# from LlamaDecoder
@@ -266,18 +162,14 @@ class PlamoDecoderLayer(nn.Module):
hidden_states_sa, cache = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
cache=cache,
)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
# Residual ("Parallel Layers" is used here, which is different from the normal residual connection)
# See "GPT-NeoX-20B: An Open-Source Autoregressive Language Model" for Parallel Layers
hidden_states = residual + hidden_states_sa + hidden_states_mlp
return hidden_states, cache # type: ignore
return hidden_states, cache
class PlamoDecoder(nn.Module):
@@ -289,24 +181,14 @@ class PlamoDecoder(nn.Module):
class PlamoModel(nn.Module):
config_class = ModelArgs
_no_split_modules: List[str]
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PlamoDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = PlamoDecoder(config) # type: ignore
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
def __call__(
self,
@@ -326,10 +208,9 @@ class PlamoModel(nn.Module):
else:
if cache[0] is not None:
past_key_values_length = cache[0][0].shape[2]
position_ids = _create_position_ids(h.shape[1], past_key_values_length)
for e, layer in enumerate(self.layers.layers):
h, c = layer(h, mask, position_ids, cache[e])
h, c = layer(h, mask, cache[e])
if cache is not None:
cache[e] = c
else:
@@ -338,22 +219,13 @@ class PlamoModel(nn.Module):
return self.norm(h), cache
def _create_position_ids(seq_length: int, past_key_values_length: int = 0) -> mx.array:
# create position_ids on the fly for batch generation
position_ids = mx.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=mx.int64
)
position_ids = position_ids[None, ...].reshape(-1, seq_length)
return position_ids
class Model(nn.Module):
def __init__(self, config: PretrainedConfig) -> None:
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model = PlamoModel(config)
self.model_type = args.model_type
self.model = PlamoModel(args)
self.lm_head: nn.Module = nn.Linear(
config.hidden_size, config.vocab_size, bias=False
args.hidden_size, args.vocab_size, bias=False
)
def __call__(
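Much of the plamo code removed above is hand-rolled rotary-embedding plumbing: cos/sin caches plus position_ids threaded through every layer. It is replaced by nn.RoPE applied with an offset equal to the number of key/value positions already in the cache. A toy sketch of that idiom; the shapes and values here are made up:

import mlx.core as mx
import mlx.nn as nn

head_dim = 64
rope = nn.RoPE(head_dim, traditional=False, base=10000, scale=1.0)

q = mx.random.normal((1, 8, 5, head_dim))  # batch 1, 8 heads, 5 new positions
past = 12                                   # positions already in the KV cache
q_rot = rope(q, offset=past)                # rotations start at position 12

Because the offset encodes the absolute position, there is no need to build or cache explicit position_ids, which is why that plumbing could be deleted.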

View File

@@ -9,6 +9,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
@@ -160,6 +161,7 @@ class QwenModel(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = QwenModel(config)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=not config.no_bias

View File

@@ -9,6 +9,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
@@ -190,6 +191,7 @@ class Qwen2Model(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

View File

@@ -11,6 +11,7 @@ from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
max_position_embeddings: int
model_type: str
vocab_size: int
hidden_size: int
num_attention_heads: int
@@ -169,6 +170,7 @@ class StableLM(nn.Module):
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = StableLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)