Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-07-24 05:01:16 +08:00

phi-2 draft

This commit is contained in:
parent 9c7e996ff0
commit a466cc5191
phi2/README.md  (new file, 24 lines)
@@ -0,0 +1,24 @@
# Phi-2

Phi-2 is a 2.7B parameter model released by Microsoft, trained on a mixture of GPT-4 outputs and clean web text.
Its performance rivals that of much larger models.

## Downloading and Converting Weights

To download and convert the model:

```sh
python phi2/convert.py
```

This writes the converted weights to `weights/phi-2.npz`.

## Running the Model

🚧 (Not yet done) To run the model:

```sh
python phi2/generate.py
```

For now, layer-by-layer forward-pass outputs are shown in `phi2/phi2_outputs.txt`.
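As a quick sanity check after conversion, the archive can be inspected with NumPy. This is only a sketch, assuming the default output path used by the convert script above:

```python
import numpy as np

# List a few of the converted parameters and their shapes/dtypes.
weights = np.load("weights/phi-2.npz")
for name in sorted(weights.files)[:5]:
    print(name, weights[name].shape, weights[name].dtype)
print(f"{len(weights.files)} arrays total")
```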
phi2/__init__.py  (new file, 0 lines)

phi2/convert.py  (new file, 67 lines)
@@ -0,0 +1,67 @@
from transformers import AutoModelForCausalLM

import numpy
import os


def split_attention_matrix(state_dict, key) -> dict:
    # key is e.g. "transformer.h.0.mixer.Wqkv"; the fused projection has a
    # (3 * model_dim, model_dim) weight and a (3 * model_dim,) bias.
    _, model_dim = state_dict[key + ".weight"].shape

    # (3 * model_dim, model_dim)
    Wqkv_weight_key = key + ".weight"
    Wq_weight = state_dict[Wqkv_weight_key][:model_dim, :]
    Wk_weight = state_dict[Wqkv_weight_key][model_dim : 2 * model_dim, :]
    Wv_weight = state_dict[Wqkv_weight_key][2 * model_dim :, :]

    # (3 * model_dim,)
    Wqkv_bias_key = key + ".bias"
    Wq_bias = state_dict[Wqkv_bias_key][:model_dim]
    Wk_bias = state_dict[Wqkv_bias_key][model_dim : 2 * model_dim]
    Wv_bias = state_dict[Wqkv_bias_key][2 * model_dim :]

    out_key = key.replace("mixer.Wqkv", "self_attention")

    return {
        out_key + ".query_proj.weight": Wq_weight,
        out_key + ".query_proj.bias": Wq_bias,
        out_key + ".key_proj.weight": Wk_weight,
        out_key + ".key_proj.bias": Wk_bias,
        out_key + ".value_proj.weight": Wv_weight,
        out_key + ".value_proj.bias": Wv_bias,
    }


def replace_key(key: str) -> str:
    # Map Hugging Face parameter names onto the names used in phi2/model.py.
    if "wte.weight" in key:
        key = "wte.weight"

    if ".mlp" in key:
        key = key.replace(".mlp", "")

    if ".mixer.out_proj" in key:
        key = key.replace(".mixer", ".self_attention")

    return key


def convert():
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
    )
    state_dict = model.state_dict()
    keys = list(state_dict.keys())

    for key in keys:
        if ".mixer.Wqkv.weight" not in key:
            continue
        # Strip the literal ".weight" suffix (str.rstrip strips characters,
        # not a suffix) to get e.g. "transformer.h.0.mixer.Wqkv".
        key_stub = key[: -len(".weight")]
        state_dict.update(split_attention_matrix(state_dict, key_stub))

        del state_dict[key_stub + ".weight"]
        del state_dict[key_stub + ".bias"]

    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
    os.makedirs("weights", exist_ok=True)
    numpy.savez("weights/phi-2.npz", **weights)


if __name__ == "__main__":
    convert()
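The split above just slices the fused Wqkv projection into three equal row blocks. A small self-contained sketch of that idea on dummy data (the dimension here is illustrative, not Phi-2's real 2560):

```python
import numpy as np

model_dim = 4  # illustrative only; Phi-2 uses model_dim = 2560
Wqkv = np.arange(3 * model_dim * model_dim).reshape(3 * model_dim, model_dim)

# Row blocks [0:d], [d:2d], [2d:3d] are the query, key and value projections.
Wq = Wqkv[:model_dim]
Wk = Wqkv[model_dim : 2 * model_dim]
Wv = Wqkv[2 * model_dim :]

# Stacking them back reproduces the fused matrix, so the split is lossless.
assert np.array_equal(np.concatenate([Wq, Wk, Wv], axis=0), Wqkv)
```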
phi2/hf_model.py  (new file, 23 lines)
@@ -0,0 +1,23 @@
from transformers import AutoModelForCausalLM, AutoTokenizer


if __name__ == "__main__":
    # Hugging Face reference implementation, used to produce the comparison
    # outputs in phi2/phi2_outputs.txt.
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    inputs = tokenizer(
        '''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
        return_tensors="pt",
        return_attention_mask=False,
    )

    print(model(**inputs))

    # outputs = model.generate(**inputs, max_length=200)
    # text = tokenizer.batch_decode(outputs)[0]
    # print(text)
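The per-layer dumps in phi2_outputs.txt presumably come from instrumenting this reference model. One way to capture such activations is a PyTorch forward hook appended to the script above; the module-name pattern is an assumption about the remote-code Phi-2 implementation, not something verified here:

```python
import torch

def capture(name):
    def hook(module, args, output):
        # Some modules return tuples; keep only the main tensor.
        tensor = output[0] if isinstance(output, tuple) else output
        print(name, tuple(tensor.shape))
        print(tensor)
    return hook

# Hypothetical: hook the first transformer block by matching its name.
for name, module in model.named_modules():
    if name.endswith("h.0"):
        module.register_forward_hook(capture(name))

with torch.no_grad():
    model(**inputs)
```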
phi2/model.py  (new file, 232 lines)
@@ -0,0 +1,232 @@
from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten, tree_map
from transformers import AutoTokenizer

import mlx.core as mx
import mlx.nn as nn
import math


@dataclass
class ModelArgs:
    max_sequence_length: int = 2048
    num_vocab: int = 51200
    model_dim: int = 2560
    num_heads: int = 32
    num_layers: int = 32
    rotary_dim: int = 32


class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __call__(self, input: mx.array) -> mx.array:
        return (
            0.5
            * input
            * (
                1.0
                + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3)))
            )
        )


class RoPEAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int, bias: bool = True):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=bias)
        self.key_proj = nn.Linear(dims, dims, bias=bias)
        self.value_proj = nn.Linear(dims, dims, bias=bias)
        self.out_proj = nn.Linear(dims, dims, bias=bias)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Extract some shapes
        num_heads = self.num_heads
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

        # Add RoPE to the queries and keys and combine them with the cache
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Note that we return the keys and values to possibly be used as a cache
        return self.out_proj(values_hat), (keys, values)


class ParallelBlock(nn.Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.self_attention = RoPEAttention(dims, num_heads, bias=True)
        self.ln = nn.LayerNorm(dims)
        self.fc1 = nn.Linear(dims, mlp_dims)
        self.fc2 = nn.Linear(mlp_dims, dims)
        self.act = NewGELUActivation()

    def __call__(self, x, x_mask):
        # Attention and MLP branches both read the same layer-normed input and
        # are summed with the residual (Phi-2's parallel block layout).
        residual = x
        hidden_states = self.ln(x)
        attn_outputs, _ = self.self_attention(
            hidden_states, hidden_states, hidden_states, x_mask
        )
        ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states)))

        hidden_states = attn_outputs + ff_hidden_states + residual

        return hidden_states


class TransformerDecoder(nn.Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.h = [ParallelBlock(dims, num_heads, mlp_dims) for _ in range(num_layers)]

    def __call__(self, x, x_mask):
        for layer in self.h:
            x = layer(x, x_mask)
        return x


class Phi2(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
        self.transformer = TransformerDecoder(
            num_layers=config.num_layers,
            dims=config.model_dim,
            num_heads=config.num_heads,
        )

        self.lm_head = LanguageModelingHead(config)

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: mx.array = None,
    ) -> mx.array:
        x = self.wte(input_ids)

        if attention_mask is not None:
            # convert 0's to -infs, 1's to 0's, and make it broadcastable
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))
        else:
            attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(
                x.shape[1]
            )
            attention_mask = attention_mask.astype(x.dtype)

        y = self.transformer(x, attention_mask)
        return self.lm_head(y)

    def generate(self, input_ids, temp=1.0):
        # Running list of token ids (batch of one). Per-layer key/value caching
        # is not wired up yet in this draft, so every step below re-runs the
        # full sequence through the transformer.
        tokens = input_ids.tolist()

        # Make an additive causal mask. We will need that to process the prompt.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1])
        mask = mask.astype(self.wte.weight.dtype)

        # First we process the prompt the same way as in __call__. Once
        # per-layer caching is implemented, this is where the caches would be
        # collected, e.g.:
        # for layer in self.transformer.h:
        #     x, c = layer(x, mask=mask)
        #     cache.append(c)  # <--- store the per layer cache in a
        #                              simple python list
        x = self.wte(input_ids)
        x = self.transformer(x, mask)
        y = self.lm_head(x[:, -1])  # <--- we only care about the last logits
        #                                   that generate the next token
        y = mx.random.categorical(y * (1 / temp))

        # y now has size [1]
        # Since MLX is lazily evaluated nothing is computed yet.
        # Calling y.item() would force the computation to happen at
        # this point but we can also choose not to do that and let the
        # user choose when to start the computation.
        yield y
        tokens[0].append(y.item())

        # Now that we have parsed the prompt and generated the first token we
        # need to feed it back into the model and loop to generate the rest.
        while True:
            # Re-embed the whole sequence (prompt plus generated tokens) and
            # rebuild the causal mask to match the new length.
            x = self.wte(mx.array(tokens))
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)
            x = self.transformer(x, mask)
            y = self.lm_head(x[:, -1])
            y = mx.random.categorical(y * (1 / temp))
            tokens[0].append(y.item())

            yield y


class LanguageModelingHead(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.model_dim)
        self.linear = nn.Linear(config.model_dim, config.num_vocab)

    def __call__(self, inputs):
        return self.linear(self.ln(inputs))


if __name__ == "__main__":
    model = Phi2(ModelArgs())

    weights = mx.load("weights/phi-2.npz")
    weights = tree_unflatten(list(weights.items()))
    weights = tree_map(lambda p: mx.array(p), weights)

    model.update(weights)

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    tokens = tokenizer(
        '''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
        return_tensors="np",
        return_attention_mask=False,
    )

    tokens = {key: mx.array(v) for key, v in tokens.items()}

    print(
        '''def print_prime(n):
   """
   Print all primes between 1 and n
   """'''
    )
    for output in model.generate(**tokens):
        print(tokenizer.decode(output.item()))
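Since `generate()` yields tokens indefinitely, a caller normally bounds the loop. A minimal sketch building on the `model`, `tokens`, and `tokenizer` defined in the `__main__` block above; `max_tokens` is an illustrative assumption, not part of this draft:

```python
max_tokens = 100  # assumed cap; the draft has no stopping criterion
generated = []
for i, token in enumerate(model.generate(tokens["input_ids"], temp=1.0)):
    if i >= max_tokens:
        break
    generated.append(token.item())  # .item() forces the lazy computation

# Decode the whole continuation in one pass.
print(tokenizer.decode(generated))
```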
phi2/phi2_outputs.txt  (new file, 63 lines)
@@ -0,0 +1,63 @@
(HF) Output of Embeddings

tensor([[[-0.0353,  0.0045,  0.0208,  ..., -0.0117,  0.0041,  0.0075],
         [-0.0172,  0.0236, -0.0051,  ...,  0.0141,  0.0115,  0.0058],
         [-0.0148,  0.0043, -0.0252,  ...,  0.0179,  0.0025, -0.0008],
         ...,
         [ 0.0003,  0.0051,  0.0002,  ...,  0.0043,  0.0075,  0.0049],
         [-0.0110,  0.0472,  0.0030,  ...,  0.0098, -0.0075,  0.0146],
         [-0.0085, -0.0219, -0.0016,  ..., -0.0059,  0.0109, -0.0016]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<EmbeddingBackward0>)

(MLX) Output of Embeddings

array([[[-0.0352783, 0.00445175, 0.020813, ..., -0.0117188, 0.00411606, 0.00748444],
        [-0.0171509, 0.0236053, -0.00508881, ..., 0.0141144, 0.0115204, 0.00582504],
        [-0.0147858, 0.00426102, -0.0252075, ..., 0.0179443, 0.0024662, -0.00076437],
        ...,
        [0.000337124, 0.00508499, 0.000193119, ..., 0.00427628, 0.00753403, 0.00492477],
        [-0.0110092, 0.0472107, 0.00295448, ..., 0.00982666, -0.00747681, 0.0145721],
        [-0.00852203, -0.0218964, -0.00161839, ..., -0.00592422, 0.0108643, -0.00162697]]], dtype=float16)

(HF) Output of First Attention Layer

tensor([[[-0.2000,  0.4849,  0.9863,  ..., -0.2209,  0.1355,  0.3469],
         [ 0.4922, -0.3865,  0.8428,  ...,  0.5894, -0.0069, -0.5278],
         [ 0.0902,  0.1028,  0.6826,  ...,  0.1394, -0.8145, -0.1880],
         ...,
         [ 0.2380,  0.0555, -0.3005,  ...,  0.0372, -0.0895,  0.0255],
         [ 0.2512,  0.1949,  0.3401,  ...,  0.3625, -0.3103, -0.1064],
         [-0.0905,  0.0665,  0.5210,  ..., -0.0767, -0.2460, -0.1449]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
torch.Size([1, 23, 2560])

(MLX) Output of First Attention Layer

array([[[-0.199973, 0.485224, 0.987237, ..., -0.220847, 0.13511, 0.346074],
        [0.44883, -0.271683, 0.877478, ..., 0.653217, -0.0929724, -0.711176],
        [-0.233398, 5.7824e-05, 0.435001, ..., 0.0504494, -0.623998, -0.438785],
        ...,
        [0.123587, -0.237459, -0.447518, ..., 0.0653363, -0.0767153, -0.341505],
        [0.187798, 0.331209, 0.0827338, ..., 0.529453, -0.582141, -0.165316],
        [-0.413614, 0.134572, 0.685769, ..., 0.0796088, 0.0217719, -0.118885]]], dtype=float32)
[1, 23, 2560]

(HF) Overall Output of Inputs:

tensor([[[ 6.4688,  5.1016,  1.9658,  ..., -2.9043, -2.9043, -2.9043],
         [ 5.2188,  6.4414,  5.1914,  ..., -0.1852, -0.1862, -0.1866],
         [ 4.3516,  5.3281,  5.9922,  ..., -0.3689, -0.3699, -0.3696],
         ...,
         [10.4141, 11.7031, 12.5859,  ...,  0.7778,  0.7769,  0.7754],
         [10.7188, 11.7891, 13.3125,  ...,  1.6123,  1.6113,  1.6104],
         [10.8047, 12.0234, 12.4375,  ...,  0.2321,  0.2314,  0.2317]]],

(MLX) Overall Output of Inputs:

array([[[6.46632, 5.10102, 1.96306, ..., -2.90427, -2.90341, -2.90392],
        [4.5092, 5.90938, 4.98036, ..., -0.411165, -0.412062, -0.412547],
        [4.34246, 5.7794, 6.13245, ..., -0.40106, -0.402052, -0.401838],
        ...,
        [6.61827, 10.4022, 12.1672, ..., 0.602787, 0.602138, 0.600666],
        [7.96546, 12.9569, 14.7947, ..., -0.347764, -0.348587, -0.34937],
        [8.22272, 10.6631, 11.5968, ..., -1.12037, -1.12025, -1.12152]]], dtype=float32)
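These dumps are compared by eye. A hedged sketch of how the agreement could instead be checked numerically, assuming the full HF and MLX arrays are available as NumPy arrays (the variable names and tolerance are illustrative):

```python
import numpy as np

def compare(hf_array: np.ndarray, mlx_array: np.ndarray, name: str, atol: float = 1e-2):
    # hf_array / mlx_array would hold the full tensors printed above, obtained
    # e.g. via hf_out.detach().float().cpu().numpy() and np.array(mlx_out).
    a = hf_array.astype(np.float32)
    b = mlx_array.astype(np.float32)
    diff = np.abs(a - b)
    print(f"{name}: max abs diff {diff.max():.4f}, "
          f"mean abs diff {diff.mean():.4f}, allclose={np.allclose(a, b, atol=atol)}")

# compare(hf_embeddings, mlx_embeddings, "embeddings")
# compare(hf_attn0, mlx_attn0, "first attention layer")
```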