mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Fix decoder mask
This commit is contained in:
parent
f26e81ccc9
commit
7dcf2b688d
12
t5/t5.py
12
t5/t5.py
@ -124,9 +124,12 @@ class MultiHeadAttention(nn.Module):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
# print(f"queries: {queries}, {queries.abs().sum()}")
|
||||
# print(f"keys: {keys}, {keys.abs().sum()}")
|
||||
# print(f"values: {values}, {values.abs().sum()}")
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
@ -289,11 +292,12 @@ class T5(nn.Module):
|
||||
x = self.wte(inputs)
|
||||
y = self.encoder(x, mask=None) # , cache)
|
||||
|
||||
if x.shape[1] > 1 and mask is None:
|
||||
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
decoder_inputs = self.wte(decoder_inputs)
|
||||
decoder_n_tokens = decoder_inputs.shape[1]
|
||||
if decoder_n_tokens > 1 and mask is None:
|
||||
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
decoder_inputs = self.wte(decoder_inputs)
|
||||
y, cache = self.decoder(
|
||||
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
||||
) # , cache)
|
||||
|
Loading…
Reference in New Issue
Block a user