mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix decoder mask
This commit is contained in:
parent
f26e81ccc9
commit
7dcf2b688d
12
t5/t5.py
12
t5/t5.py
@ -124,9 +124,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
keys = self.key_proj(keys)
|
keys = self.key_proj(keys)
|
||||||
values = self.value_proj(values)
|
values = self.value_proj(values)
|
||||||
|
# print(f"queries: {queries}, {queries.abs().sum()}")
|
||||||
|
# print(f"keys: {keys}, {keys.abs().sum()}")
|
||||||
|
# print(f"values: {values}, {values.abs().sum()}")
|
||||||
|
|
||||||
num_heads = self.num_heads
|
num_heads = self.num_heads
|
||||||
B, L, D = queries.shape
|
B, L, _ = queries.shape
|
||||||
_, S, _ = keys.shape
|
_, S, _ = keys.shape
|
||||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||||
@ -289,11 +292,12 @@ class T5(nn.Module):
|
|||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
y = self.encoder(x, mask=None) # , cache)
|
y = self.encoder(x, mask=None) # , cache)
|
||||||
|
|
||||||
if x.shape[1] > 1 and mask is None:
|
decoder_inputs = self.wte(decoder_inputs)
|
||||||
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
decoder_n_tokens = decoder_inputs.shape[1]
|
||||||
|
if decoder_n_tokens > 1 and mask is None:
|
||||||
|
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
|
||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
decoder_inputs = self.wte(decoder_inputs)
|
|
||||||
y, cache = self.decoder(
|
y, cache = self.decoder(
|
||||||
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
||||||
) # , cache)
|
) # , cache)
|
||||||
|
Loading…
Reference in New Issue
Block a user