Lazy import + refactor Lora layer addition (#426)

* lazy model import in mlx_lm
* change lora loading
* fix olmo lora
* remove a bunch of unused stuff from plamo
* move phixtral to mlx-lm and out of llms/
commit d4666615bb
parent 4576946151
Author: Awni Hannun
Date: 2024-02-12 10:51:02 -08:00 (committed by GitHub)

15 changed files with 127 additions and 393 deletions
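The first bullet replaces mlx_lm's eager import of every model module with an on-demand import keyed by the config's model_type, so loading one model no longer pays the import cost of every other architecture. A minimal sketch of the pattern, with illustrative names rather than the exact helper added in this PR:

import importlib

def _get_classes(config: dict):
    # Import only the module that matches this config; e.g. a
    # model_type of "plamo" resolves to mlx_lm/models/plamo.py.
    model_type = config["model_type"]
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        raise ValueError(f"Model type {model_type} not supported.")
    return arch.Model, arch.ModelArgs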

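The LoRA bullets ("change lora loading", "fix olmo lora") refactor how adapter layers are attached to a loaded model; that code is not part of the plamo diff shown below. As a rough, self-contained sketch of the general technique, with a hypothetical LoRALinear wrapper and helper rather than the exact API from this PR:

import math

import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    # Hypothetical minimal wrapper: y = W x + scale * (x A) B.
    def __init__(self, linear: nn.Linear, rank: int = 8, scale: float = 2.0):
        super().__init__()
        self.linear = linear  # base projection, left untouched
        out_dims, in_dims = linear.weight.shape
        self.scale = scale
        # A is small and random; B starts at zero, so the wrapped layer
        # initially computes exactly what `linear` did.
        bound = 1.0 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(in_dims, rank))
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)

def add_lora_layers(model, num_layers: int = 16):
    # Swap the attention projections of the last `num_layers` blocks.
    # `model.layers` and the projection names are assumptions about the
    # model structure, not the exact attributes used in mlx_lm.
    for layer in model.layers[-num_layers:]:
        layer.self_attn.q_proj = LoRALinear(layer.self_attn.q_proj)
        layer.self_attn.v_proj = LoRALinear(layer.self_attn.v_proj)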

@@ -1,126 +1,25 @@
-from typing import Any, List, NamedTuple, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
-import numpy as np
-from transformers import PretrainedConfig
+
+from .base import BaseModelArgs
 
 
-class DecoderInput(NamedTuple):
-    hidden_states: mx.array
-    position_ids: mx.array
-    attention_mask: Optional[mx.array] = None
-    past_key_values: Optional[List[mx.array]] = None
-    output_hidden_states: Optional[bool] = False
-    output_attentions: Optional[bool] = False
-    use_cache: Optional[bool] = False
-    gradient_checkpointing: bool = False
-
-
-class DecoderOutput(NamedTuple):
-    hidden_states: mx.array
-    all_hidden_states: Optional[Tuple[mx.array, ...]]
-    all_self_attns: Optional[Tuple[mx.array, ...]]
-    next_decoder_cache: Optional[Tuple[mx.array, ...]]
-
-
-class ModelArgs(PretrainedConfig):  # type: ignore
-    model_type: str = "plamo"
-
-    def __init__(
-        self,
-        vocab_size: int = 32000,
-        hidden_size: int = 4096,
-        intermediate_size: int = 13312,
-        num_hidden_layers: int = 32,
-        num_attention_heads: int = 32,
-        max_position_embeddings: int = 2048,
-        initializer_range: float = 0.02,
-        rms_norm_eps: float = 1e-6,
-        use_cache: bool = True,
-        tokenizer_class: str = "PlamoTokenizer",
-        pad_token_id: Optional[int] = None,
-        bos_token_id: int = 1,
-        eos_token_id: int = 2,
-        n_shared_head: int = 8,
-        tie_word_embeddings: bool = False,
-        **kwargs: Any,
-    ) -> None:
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.n_shared_head = n_shared_head
-
-        super().__init__(
-            tokenizer_class=tokenizer_class,
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-
-class RotaryEmbedding:
-    def __init__(
-        self, dim: int, max_position_embeddings: int = 2048, base: int = 10000
-    ) -> None:
-        super().__init__()
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        self.inv_freq = 1.0 / mx.power(
-            self.base, mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim
-        )
-        self.cos_cached = mx.zeros((1, 1, max_position_embeddings, dim))
-        self.sin_cached = mx.zeros((1, 1, max_position_embeddings, dim))
-        self._set_cos_sin_cache(max_position_embeddings)
-
-    def _set_cos_sin_cache(self, seq_len: int) -> None:
-        self.max_seq_len_cached = seq_len
-        t = mx.arange(self.max_seq_len_cached)  # type: ignore
-        freqs = mx.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = mx.concatenate((freqs, freqs), axis=-1)
-        self.cos_cached = emb.cos()[None, None, :, :]
-        self.sin_cached = emb.sin()[None, None, :, :]
-
-    def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]:
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].astype(x.dtype),  # type: ignore
-            self.sin_cached[:, :, :seq_len, ...].astype(x.dtype),  # type: ignore
-        )
-
-
-def _rotate_half(x: mx.array) -> mx.array:
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return mx.concatenate((-x2, x1), axis=-1)
-
-
-def _rotary_pos_emb(
-    x: mx.array, cos: mx.array, sin: mx.array, position_ids: mx.array
-) -> mx.array:
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = mx.squeeze(cos, (0, 1))  # [seq_len, dim]
-    sin = mx.squeeze(sin, (0, 1))  # [seq_len, dim]
-    cos = cos[position_ids][:, None]  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids][:, None]  # [bs, 1, seq_len, dim]
-    x_embed = (x * cos) + (_rotate_half(x) * sin)
-    return x_embed
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    n_shared_head: int = 8
+    rope_theta: float = 10000
+    rope_traditional: bool = False
 
 
 class RMSNorm(nn.Module):
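With this change plamo's configuration no longer depends on transformers: ModelArgs is a plain dataclass populated from config.json instead of a PretrainedConfig subclass, and fields the MLX port never used (initializer_range, use_cache, tokenizer settings, and so on) are simply dropped. The shared BaseModelArgs supplies a from_dict constructor that ignores unknown keys; roughly, assuming the mlx_lm base class of that era:

import inspect
from dataclasses import dataclass

@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params: dict):
        # Keep only the keys this dataclass declares; extra config.json
        # entries such as "initializer_range" are silently ignored.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )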
@@ -143,7 +42,6 @@ class Attention(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         head_dim = self.hidden_size // config.num_attention_heads
-        self.max_position_embeddings = config.max_position_embeddings
 
         self.q_num_heads = config.num_attention_heads
         self.qk_dim = self.v_dim = head_dim
@@ -165,15 +63,17 @@ class Attention(nn.Module):
         self.o_proj = nn.Linear(
             self.q_num_heads * self.v_dim, self.hidden_size, bias=False
         )
-        self.rotary_emb = RotaryEmbedding(
-            self.qk_dim, max_position_embeddings=self.max_position_embeddings
+        self.rotary_emb = nn.RoPE(
+            head_dim,
+            traditional=config.rope_traditional,
+            base=config.rope_theta,
+            scale=1.0,
         )
 
     def __call__(
         self,
         hidden_states: mx.array,
         attention_mask: Optional[mx.array] = None,
-        position_ids: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         bsz, q_len, _ = hidden_states.shape
@@ -204,13 +104,11 @@ class Attention(nn.Module):
         key_states = _expand_kv(key_states)
         value_states = _expand_kv(value_states)
 
-        kv_seq_len = key_states.shape[-2]
+        kv_seq_len = 0
         if cache is not None:
             kv_seq_len += cache[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        assert position_ids is not None
-        query_states = _rotary_pos_emb(query_states, cos, sin, position_ids)
-        key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
+        query_states = self.rotary_emb(query_states, offset=kv_seq_len)
+        key_states = self.rotary_emb(key_states, offset=kv_seq_len)
 
         if cache is not None:
             # reuse k, v, self_attention
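nn.RoPE subsumes both the hand-rolled RotaryEmbedding and the position_ids plumbing: the rotation depends only on a token's absolute position, so the integer offset (the number of cached tokens preceding the current block) carries all the information the per-token position_ids array did during incremental decoding. A small self-contained check of that equivalence:

import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(64, traditional=False, base=10000)
x = mx.random.normal((1, 8, 5, 64))  # [batch, heads, seq, head_dim]

# Rotating the full sequence at once...
full = rope(x)

# ...matches rotating only the suffix, with `offset` set to the number
# of tokens already in the cache.
suffix = rope(x[:, :, 3:, :], offset=3)
print(mx.allclose(full[:, :, 3:, :], suffix, atol=1e-5))  # True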
@@ -235,10 +133,9 @@ class MLP(nn.Module):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = nn.silu
 
     def __call__(self, x: mx.array) -> mx.array:
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))  # type: ignore
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))  # type: ignore
 
 
 class PlamoDecoderLayer(nn.Module):
@@ -254,7 +151,6 @@ class PlamoDecoderLayer(nn.Module):
         self,
         hidden_states: mx.array,
         attention_mask: Optional[mx.array] = None,
-        position_ids: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> Tuple[Any, ...]:
         # from LlamaDecoder
@@ -266,18 +162,14 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states_sa, cache = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             cache=cache,
         )
 
         # Fully Connected
         hidden_states_mlp = self.mlp(hidden_states)
 
-        # Residual ("Parallel Layers" is used here, which is different from the normal residual connection)
-        # See "GPT-NeoX-20B: An Open-Source Autoregressive Language Model" for Parallel Layers
         hidden_states = residual + hidden_states_sa + hidden_states_mlp
-
-        return hidden_states, cache  # type: ignore
+        return hidden_states, cache
 
 
 class PlamoDecoder(nn.Module):
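The two comments deleted above described the one structural quirk worth remembering here: PLaMo uses "parallel layers" in the style of GPT-NeoX-20B, where attention and MLP both read the same pre-normed input and are summed with the residual in a single step, rather than the standard sequential residual. A schematic sketch of the two shapes, with hypothetical norm/attn/mlp callables:

def parallel_block(x, norm, attn, mlp):
    # One shared pre-norm; both branches are summed with the residual.
    h = norm(x)
    return x + attn(h) + mlp(h)

def sequential_block(x, norm1, attn, norm2, mlp):
    # Standard Transformer: two norms, two residual steps.
    x = x + attn(norm1(x))
    return x + mlp(norm2(x))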
@@ -289,24 +181,14 @@ class PlamoDecoder(nn.Module):
 class PlamoModel(nn.Module):
-    config_class = ModelArgs
-    _no_split_modules: List[str]
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["PlamoDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
-
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.config = config
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = PlamoDecoder(config)  # type: ignore
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.gradient_checkpointing = False
 
     def __call__(
         self,
@@ -326,10 +208,9 @@ class PlamoModel(nn.Module):
         else:
             if cache[0] is not None:
                 past_key_values_length = cache[0][0].shape[2]
-        position_ids = _create_position_ids(h.shape[1], past_key_values_length)
 
         for e, layer in enumerate(self.layers.layers):
-            h, c = layer(h, mask, position_ids, cache[e])
+            h, c = layer(h, mask, cache[e])
             if cache is not None:
                 cache[e] = c
         else:
@@ -338,22 +219,13 @@ class PlamoModel(nn.Module):
         return self.norm(h), cache
 
 
-def _create_position_ids(seq_length: int, past_key_values_length: int = 0) -> mx.array:
-    # create position_ids on the fly for batch generation
-    position_ids = mx.arange(
-        past_key_values_length, seq_length + past_key_values_length, dtype=mx.int64
-    )
-    position_ids = position_ids[None, ...].reshape(-1, seq_length)
-
-    return position_ids
-
-
 class Model(nn.Module):
-    def __init__(self, config: PretrainedConfig) -> None:
+    def __init__(self, args: ModelArgs) -> None:
         super().__init__()
-        self.model = PlamoModel(config)
+        self.model_type = args.model_type
+        self.model = PlamoModel(args)
         self.lm_head: nn.Module = nn.Linear(
-            config.hidden_size, config.vocab_size, bias=False
+            args.hidden_size, args.vocab_size, bias=False
         )
 
     def __call__(