Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-13 05:01:12 +08:00)

Commit 88d7b67e6e (parent a466cc5191): add cache + generation, clean up some stuff
phi2/.gitignore (new file)

@@ -0,0 +1 @@
+weights.npz

@@ -60,7 +60,7 @@ def convert():
     del state_dict[key_stub + ".bias"]

     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
-    numpy.savez("weights/phi-2.npz", **weights)
+    numpy.savez("weights.npz", **weights)


 if __name__ == "__main__":
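
Note: the converted checkpoint is just a flat dict of arrays. The convert step writes it with numpy.savez, and model.py later reloads it with mx.load and rebuilds the nested parameter tree with tree_unflatten, as in this minimal sketch (the key name below is made up for illustration; real names come from replace_key() in the convert script):

    import numpy
    import mlx.core as mx
    from mlx.utils import tree_unflatten

    # hypothetical parameter name, purely for illustration
    numpy.savez("weights.npz", **{"lm_head.linear.weight": numpy.ones((4, 4), dtype="float32")})
    flat = list(mx.load("weights.npz").items())  # [(dotted name, mx.array), ...]
    nested = tree_unflatten(flat)                # {"lm_head": {"linear": {"weight": ...}}}
    print(nested["lm_head"]["linear"]["weight"].shape)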

phi2/model.py

@@ -7,7 +7,6 @@ import mlx.core as mx
 import mlx.nn as nn
-import math


 @dataclass
 class ModelArgs:
     max_sequence_length: int = 2048
@@ -18,23 +17,6 @@ class ModelArgs:
     rotary_dim: int = 32


-class NewGELUActivation(nn.Module):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
-    """
-
-    def __call__(self, input: mx.array) -> mx.array:
-        return (
-            0.5
-            * input
-            * (
-                1.0
-                + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3)))
-            )
-        )
-
-
 class RoPEAttention(nn.Module):
     def __init__(self, dims: int, num_heads: int, bias: bool = True):
         super().__init__()
@@ -77,6 +59,7 @@ class RoPEAttention(nn.Module):
         scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
         if mask is not None:
             scores = scores + mask

         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

@@ -92,19 +75,13 @@ class ParallelBlock(nn.Module):
         self.ln = nn.LayerNorm(dims)
         self.fc1 = nn.Linear(dims, mlp_dims)
         self.fc2 = nn.Linear(mlp_dims, dims)
-        self.act = NewGELUActivation()
+        self.act = nn.GELU(approx="precise")

-    def __call__(self, x, x_mask):
-        residual = x
-        hidden_states = self.ln(x)
-        attn_outputs, _ = self.self_attention(
-            hidden_states, hidden_states, hidden_states, x_mask
-        )
-        ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states)))
-
-        hidden_states = attn_outputs + ff_hidden_states + residual
-
-        return hidden_states
+    def __call__(self, x, mask, cache):
+        h = self.ln(x)
+        attn_h, cache = self.self_attention(h, h, h, mask, cache)
+        ff_h = self.fc2(self.act(self.fc1(h)))
+        return attn_h + ff_h + x, cache


 class TransformerDecoder(nn.Module):
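
Note: the removed NewGELUActivation is replaced by MLX's built-in GELU module. A quick sanity check (a sketch, assuming only that nn.GELU(approx="precise") closely approximates the same function as the removed tanh formula):

    import math
    import mlx.core as mx
    import mlx.nn as nn

    def tanh_gelu(x):
        # the formula from the removed NewGELUActivation
        return 0.5 * x * (1.0 + mx.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x**3))))

    x = mx.random.normal((8,))
    # both approximate the erf-based GELU, so the difference should be tiny
    print(mx.abs(tanh_gelu(x) - nn.GELU(approx="precise")(x)).max())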

@@ -114,10 +91,22 @@ class TransformerDecoder(nn.Module):
         super().__init__()
         self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)]

-    def __call__(self, x, x_mask):
-        for layer in self.h:
-            x = layer(x, x_mask)
-        return x
+    def __call__(self, x, mask, cache):
+        if cache is None:
+            cache = [None] * len(self.h)
+
+        for e, layer in enumerate(self.h):
+            x, cache[e] = layer(x, mask, cache[e])
+        return x, cache


+class OutputHead(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        self.ln = nn.LayerNorm(config.model_dim)
+        self.linear = nn.Linear(config.model_dim, config.num_vocab)
+
+    def __call__(self, inputs):
+        return self.linear(self.ln(inputs))
+
+
 class Phi2(nn.Module):
@@ -128,77 +117,40 @@ class Phi2(nn.Module):
             dims=config.model_dim,
             num_heads=config.num_heads,
         )

-        self.lm_head = LanguageModelingHead(config)
+        self.lm_head = OutputHead(config)

     def __call__(
         self,
-        input_ids: mx.array,
-        attention_mask: mx.array = None,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
-        x = self.wte(input_ids)
+        x = self.wte(inputs)

-        if attention_mask is not None:
-            # convert 0's to -infs, 1's to 0's, and make it broadcastable
-            attention_mask = mx.log(attention_mask)
-            attention_mask = mx.expand_dims(attention_mask, (1, 2))
-        else:
-            attention_mask = nn.MultiHeadAttention.create_additive_causal_mask(
-                x.shape[1]
-            )
-
-        y = self.transformer(x, attention_mask)
-        return self.lm_head(y)
-
-    def generate(self, input_ids, temp=1.0):
-        cache = input_ids.tolist()
-
-        # Make an additive causal mask. We will need that to process the prompt.
-        mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1])
-        mask = mask.astype(self.wte.weight.dtype)
-
-        # First we process the prompt x the same way as in __call__ but
-        # save the caches in cache
-        x = self.wte(input_ids)
-        # for l in self.layers:
-        #     x, c = l(x, mask=mask)
-        #     cache.append(c)  # <--- we store the per layer cache in a
-        #                      #      simple python list
-        x = self.transformer(x, mask)
-        y = self.lm_head(x[:, -1])  # <--- we only care about the last logits
-        #                           #      that generate the next token
-        y = mx.random.categorical(y * (1 / temp))
-
-        # y now has size [1]
-        # Since MLX is lazily evaluated nothing is computed yet.
-        # Calling y.item() would force the computation to happen at
-        # this point but we can also choose not to do that and let the
-        # user choose when to start the computation.
-        cache += [y.item()]
-
-        # Now we parsed the prompt and generated the first token we
-        # need to feed it back into the model and loop to generate the
-        # rest.
-        while True:
-            # Unsqueezing the last dimension to add a sequence length
-            # dimension of 1
-            x = self.wte(mx.array(cache))
-            x = self.transformer(x, mask)
-            y = self.lm_head(x[:, -1])
-            y = mx.random.categorical(y * (1 / temp))
-            cache += [y[0].item()]
-
-            yield y
-
-
-class LanguageModelingHead(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
-        self.ln = nn.LayerNorm(config.model_dim)
-        self.linear = nn.Linear(config.model_dim, config.num_vocab)
-
-    def __call__(self, inputs):
-        return self.linear(self.ln(inputs))
+        mask = None
+        if x.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(x.dtype)
+
+        y, cache = self.transformer(x, mask, cache)
+        return self.lm_head(y), cache
+
+
+def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))
+
+    logits, cache = model(prompt)
+    y = sample(logits[:, -1, :])
+    yield y
+
+    while True:
+        logits, cache = model(y[:, None], cache=cache)
+        y = sample(logits.squeeze(1))
+        yield y


 if __name__ == "__main__":
@@ -206,27 +158,28 @@ if __name__ == "__main__":

     weights = mx.load("weights/phi-2.npz")
     weights = tree_unflatten(list(weights.items()))
-    weights = tree_map(lambda p: mx.array(p), weights)
+    weights = tree_map(lambda p: mx.array(p, mx.float32), weights)

     model.update(weights)

     tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-    tokens = tokenizer(
-        '''def print_prime(n):
-   """
-   Print all primes between 1 and n
-   """''',
+    prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.",
         return_tensors="np",
         return_attention_mask=False,
-    )
+    )["input_ids"]

-    tokens = {key: mx.array(v) for key, v in tokens.items()}
+    prompt = mx.array(prompt)

-    print(
-        '''def print_prime(n):
-   """
-   Print all primes between 1 and n
-   """'''
-    )
-    for output in model.generate(**tokens):
-        print(tokenizer.decode(output.item()))
+    tokens_per_eval = 1
+    max_tokens = 100
+
+    tokens = []
+    for token, _ in zip(generate(prompt, model), range(max_tokens)):
+        tokens.append(token)
+
+        if (len(tokens) % tokens_per_eval) == 0:
+            mx.eval(tokens)
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s, end="", flush=True)
+            tokens = []
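
Note: a condensed sketch of how the new top-level generate() is meant to be driven (it mirrors the __main__ loop above; model and tokenizer setup are omitted, and the 20-token cap is arbitrary):

    prompt = mx.array(
        tokenizer("Hello", return_tensors="np", return_attention_mask=False)["input_ids"]
    )
    for token, _ in zip(generate(prompt, model), range(20)):
        mx.eval(token)  # arrays are lazy; force the computation for this step
        print(tokenizer.decode([token.item()]), end="", flush=True)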

phi2/requirements.txt (new file)

@@ -0,0 +1,3 @@
+einops
+mlx
+transformers