mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 08:43:26 +08:00
Fix: use mx.repeat instead of mx.tile when expanding grouped B/C and shared key/value heads
This commit is contained in:
parent
d7426c7750
commit
21c0abaf23
@@ -498,11 +498,11 @@ def selective_state_update_ref(
|
|||||||
dt = nn.softplus(dt) if dt_softplus else dt
|
dt = nn.softplus(dt) if dt_softplus else dt
|
||||||
dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
|
dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
|
||||||
B = mx.reshape(
|
B = mx.reshape(
|
||||||
mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)),
|
mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
|
||||||
(batch, nheads, dstate),
|
(batch, nheads, dstate),
|
||||||
) # (batch, nheads, dstate)
|
) # (batch, nheads, dstate)
|
||||||
C = mx.reshape(
|
C = mx.reshape(
|
||||||
mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
|
mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
|
||||||
(batch, nheads, dstate),
|
(batch, nheads, dstate),
|
||||||
) # (batch, nheads, dstate)
|
) # (batch, nheads, dstate)
|
||||||
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate)
|
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate)
|
||||||
@@ -1070,8 +1070,8 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# expand shared kv
|
# expand shared kv
|
||||||
assert self.k_num_heads == self.v_num_heads
|
assert self.k_num_heads == self.v_num_heads
|
||||||
key_states = mx.tile(key_states, (1, self.n_group, 1, 1))
|
key_states = mx.repeat(key_states, self.n_group, 1)
|
||||||
value_states = mx.tile(value_states, (1, self.n_group, 1, 1))
|
value_states = mx.repeat(value_states, self.n_group, 1)
|
||||||
|
|
||||||
full_attn = self.layer_idx in self.config.full_attention_idx
|
full_attn = self.layer_idx in self.config.full_attention_idx
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user