diff --git a/phi2/.gitignore b/phi2/.gitignore new file mode 100644 index 00000000..258ec872 --- /dev/null +++ b/phi2/.gitignore @@ -0,0 +1 @@ +weights.npz diff --git a/phi2/convert.py b/phi2/convert.py index cd2f77aa..3c821f69 100644 --- a/phi2/convert.py +++ b/phi2/convert.py @@ -60,7 +60,7 @@ def convert(): del state_dict[key_stub + ".bias"] weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} - numpy.savez("weights/phi-2.npz", **weights) + numpy.savez("weights.npz", **weights) if __name__ == "__main__": diff --git a/phi2/model.py b/phi2/model.py index 991bf193..5253a266 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -7,7 +7,6 @@ import mlx.core as mx import mlx.nn as nn import math - @dataclass class ModelArgs: max_sequence_length: int = 2048 @@ -18,23 +17,6 @@ class ModelArgs: rotary_dim: int = 32 -class NewGELUActivation(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see - the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 - """ - - def __call__(self, input: mx.array) -> mx.array: - return ( - 0.5 - * input - * ( - 1.0 - + mx.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * (input**3))) - ) - ) - - class RoPEAttention(nn.Module): def __init__(self, dims: int, num_heads: int, bias: bool = True): super().__init__() @@ -77,6 +59,7 @@ class RoPEAttention(nn.Module): scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) if mask is not None: scores = scores + mask + scores = mx.softmax(scores, axis=-1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -92,19 +75,13 @@ class ParallelBlock(nn.Module): self.ln = nn.LayerNorm(dims) self.fc1 = nn.Linear(dims, mlp_dims) self.fc2 = nn.Linear(mlp_dims, dims) - self.act = NewGELUActivation() + self.act = nn.GELU(approx="precise") - def __call__(self, x, x_mask): - residual = x - hidden_states = self.ln(x) - attn_outputs, _ = self.self_attention( - hidden_states, hidden_states, hidden_states, x_mask - ) - ff_hidden_states = self.fc2(self.act(self.fc1(hidden_states))) - - hidden_states = attn_outputs + ff_hidden_states + residual - - return hidden_states + def __call__(self, x, mask, cache): + h = self.ln(x) + attn_h, cache = self.self_attention(h, h, h, mask, cache) + ff_h = self.fc2(self.act(self.fc1(h))) + return attn_h + ff_h + x, cache class TransformerDecoder(nn.Module): @@ -114,10 +91,22 @@ class TransformerDecoder(nn.Module): super().__init__() self.h = [ParallelBlock(dims, num_heads, mlp_dims) for i in range(num_layers)] - def __call__(self, x, x_mask): - for layer in self.h: - x = layer(x, x_mask) - return x + def __call__(self, x, mask, cache): + if cache is None: + cache = [None] * len(self.h) + + for e, layer in enumerate(self.h): + x, cache[e] = layer(x, mask, cache[e]) + return x, cache + + +class OutputHead(nn.Module): + def __init__(self, config: ModelArgs) -> None: + self.ln = nn.LayerNorm(config.model_dim) + self.linear = nn.Linear(config.model_dim, config.num_vocab) + + def __call__(self, inputs): + return self.linear(self.ln(inputs)) class Phi2(nn.Module): @@ -128,77 +117,40 @@ class Phi2(nn.Module): dims=config.model_dim, num_heads=config.num_heads, ) - - self.lm_head = LanguageModelingHead(config) + self.lm_head = OutputHead(config) def __call__( self, - input_ids: mx.array, - attention_mask: mx.array = None, + inputs: mx.array, + mask: mx.array = None, + cache: mx.array = None, ) -> tuple[mx.array, mx.array]: - x = self.wte(input_ids) + x = self.wte(inputs) - if attention_mask is not None: - # convert 0's to -infs, 1's to 0's, and make it broadcastable - attention_mask = mx.log(attention_mask) - attention_mask = mx.expand_dims(attention_mask, (1, 2)) + mask = None + if x.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(x.dtype) + + y, cache = self.transformer(x, mask, cache) + return self.lm_head(y), cache + + +def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) else: - attention_mask = nn.MultiHeadAttention.create_additive_causal_mask( - x.shape[1] - ) + return mx.random.categorical(logits * (1 / temp)) - y = self.transformer(x, attention_mask) - return self.lm_head(y) + logits, cache = model(prompt) + y = sample(logits[:, -1, :]) + yield y - def generate(self, input_ids, temp=1.0): - cache = input_ids.tolist() - - # Make an additive causal mask. We will need that to process the prompt. - mask = nn.MultiHeadAttention.create_additive_causal_mask(input_ids.shape[1]) - mask = mask.astype(self.wte.weight.dtype) - - # First we process the prompt x the same way as in __call__ but - # save the caches in cache - x = self.wte(input_ids) - # for l in self.layers: - # x, c = l(x, mask=mask) - # cache.append(c) # <--- we store the per layer cache in a - # simple python list - x = self.transformer(x, mask) - y = self.lm_head(x[:, -1]) # <--- we only care about the last logits - # that generate the next token - y = mx.random.categorical(y * (1 / temp)) - - # y now has size [1] - # Since MLX is lazily evaluated nothing is computed yet. - # Calling y.item() would force the computation to happen at - # this point but we can also choose not to do that and let the - # user choose when to start the computation. + while True: + logits, cache = model(y[:, None], cache=cache) + y = sample(logits.squeeze(1)) yield y - cache += [y.item()] - - # Now we parsed the prompt and generated the first token we - # need to feed it back into the model and loop to generate the - # rest. - while True: - # Unsqueezing the last dimension to add a sequence length - # dimension of 1 - x = self.wte(mx.array(cache)) - x = self.transformer(x, mask) - y = self.lm_head(x[:, -1]) - y = mx.random.categorical(y * (1 / temp)) - cache += [y[0].item()] - - yield y - - -class LanguageModelingHead(nn.Module): - def __init__(self, config: ModelArgs) -> None: - self.ln = nn.LayerNorm(config.model_dim) - self.linear = nn.Linear(config.model_dim, config.num_vocab) - - def __call__(self, inputs): - return self.linear(self.ln(inputs)) if __name__ == "__main__": @@ -206,27 +158,28 @@ if __name__ == "__main__": weights = mx.load("weights/phi-2.npz") weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p), weights) + weights = tree_map(lambda p: mx.array(p, mx.float32), weights) model.update(weights) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) - tokens = tokenizer( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''', + prompt = tokenizer("Write a detailed analogy between mathematics and a lighthouse.", return_tensors="np", return_attention_mask=False, - ) + )["input_ids"] - tokens = {key: mx.array(v) for key, v in tokens.items()} + prompt = mx.array(prompt) + + tokens_per_eval = 1 + max_tokens = 100 + + tokens = [] + for token, _ in zip(generate(prompt, model), range(max_tokens)): + tokens.append(token) + + if (len(tokens) % tokens_per_eval) == 0: + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, end="", flush=True) + tokens = [] - print( - '''def print_prime(n): - """ - Print all primes between 1 and n - """''' - ) - for output in model.generate(**tokens): - print(tokenizer.decode(output.item())) diff --git a/phi2/requirements.txt b/phi2/requirements.txt new file mode 100644 index 00000000..6a11f8d2 --- /dev/null +++ b/phi2/requirements.txt @@ -0,0 +1,3 @@ +einops +mlx +transformers