mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
Fix to use repeat instead of tile
This commit is contained in:
parent
d7426c7750
commit
21c0abaf23
@@ -498,11 +498,11 @@ def selective_state_update_ref(
|
||||
dt = nn.softplus(dt) if dt_softplus else dt
|
||||
dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
|
||||
B = mx.reshape(
|
||||
- mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)),
|
||||
+ mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
|
||||
(batch, nheads, dstate),
|
||||
) # (batch, nheads, dstate)
|
||||
C = mx.reshape(
|
||||
- mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
|
||||
+ mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
|
||||
(batch, nheads, dstate),
|
||||
) # (batch, nheads, dstate)
|
||||
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate)
|
||||
@@ -1070,8 +1070,8 @@ class Attention(nn.Module):
|
||||
|
||||
# expand shared kv
|
||||
assert self.k_num_heads == self.v_num_heads
|
||||
- key_states = mx.tile(key_states, (1, self.n_group, 1, 1))
|
||||
- value_states = mx.tile(value_states, (1, self.n_group, 1, 1))
|
||||
+ key_states = mx.repeat(key_states, self.n_group, 1)
|
||||
+ value_states = mx.repeat(value_states, self.n_group, 1)
|
||||
|
||||
full_attn = self.layer_idx in self.config.full_attention_idx
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user