Mirror of https://github.com/ml-explore/mlx-examples.git
export and run llama in C++
@@ -74,9 +74,9 @@ class Attention(nn.Module):
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        queries = mx.unflatten(queries, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3)
+        keys = mx.unflatten(keys, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
+        values = mx.unflatten(values, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
 
         if cache is not None:
             queries = self.rope(queries, offset=cache.offset)
@@ -90,7 +90,7 @@ class Attention(nn.Module):
             queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
 
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
         return self.o_proj(output)
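Why this matters: the old code split and merged the attention heads with reshape, which bakes the concrete batch size B and sequence length L into the traced computation, while mx.unflatten and flatten express the same operation purely in terms of the last axes. A minimal sketch of the equivalence, using hypothetical sizes (B=2, L=5, 4 heads of width 8) picked only for illustration:

import mlx.core as mx

B, L, n_heads, head_dim = 2, 5, 4, 8
x = mx.random.normal((B, L, n_heads * head_dim))

# Old style: split heads by reshaping with explicit B and L.
old = x.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)

# New style: unflatten only the last axis; B and L are never referenced.
new = mx.unflatten(x, -1, (n_heads, -1)).transpose(0, 2, 1, 3)
assert mx.array_equal(old, new)

# Merging the heads back on the output side is the mirror image.
merged = new.transpose(0, 2, 1, 3)
assert mx.array_equal(merged.flatten(-2, -1), merged.reshape(B, L, -1))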
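Given the commit title ("export and run llama in C++"), the likely motivation is MLX's function export path: a graph exported with shapeless=True can be replayed from C++ on inputs of any length, but only if no concrete dimension sizes were captured during tracing, which is exactly what the reshape(B, L, ...) calls did. A sketch under that assumption (the toy function and file name here are hypothetical, not from the commit):

import mlx.core as mx

def split_and_merge(x):
    # Shape-agnostic head split/merge in the style of the patched Attention.
    h = mx.unflatten(x, -1, (4, -1)).transpose(0, 2, 1, 3)
    return h.transpose(0, 2, 1, 3).flatten(-2, -1)

example = mx.zeros((1, 8, 32))

# shapeless=True records a graph valid for any input shape; it would not
# generalize if split_and_merge hard-coded the batch or sequence size.
mx.export_function("split_and_merge.mlxfn", split_and_merge, example, shapeless=True)

The exported file can then be loaded and run from the C++ side with mx::import_function.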