Fix: use mx.repeat instead of mx.tile

This commit is contained in:
Shunta Saito 2025-02-23 14:54:23 +09:00
parent d7426c7750
commit 21c0abaf23

View File

@@ -498,11 +498,11 @@ def selective_state_update_ref(
dt = nn.softplus(dt) if dt_softplus else dt
dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
B = mx.reshape(
- mx.tile(mx.expand_dims(B, axis=2), (1, 1, nheads // ngroups, 1)),
+ mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
(batch, nheads, dstate),
) # (batch, nheads, dstate)
C = mx.reshape(
- mx.tile(mx.expand_dims(C, axis=2), (1, 1, nheads // ngroups, 1)),
+ mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
(batch, nheads, dstate),
) # (batch, nheads, dstate)
dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(B, axis=-2) # (batch, nheads, dim, dstate)
@@ -1070,8 +1070,8 @@ class Attention(nn.Module):
# expand shared kv
assert self.k_num_heads == self.v_num_heads
- key_states = mx.tile(key_states, (1, self.n_group, 1, 1))
- value_states = mx.tile(value_states, (1, self.n_group, 1, 1))
+ key_states = mx.repeat(key_states, self.n_group, 1)
+ value_states = mx.repeat(value_states, self.n_group, 1)
full_attn = self.layer_idx in self.config.full_attention_idx