mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
No scaling, no encoder mask
This commit is contained in:
parent
64e7eaccb8
commit
d12db65eeb
8
t5/t5.py
8
t5/t5.py
@ -125,8 +125,8 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scores = queries @ keys
|
||||||
scores = (queries * scale) @ keys
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
@ -274,12 +274,14 @@ class T5(nn.Module):
|
|||||||
) -> tuple[mx.array, mx.array]:
|
) -> tuple[mx.array, mx.array]:
|
||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
y = self.encoder(x, mask=None) #, cache)
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if x.shape[1] > 1:
|
if x.shape[1] > 1:
|
||||||
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
y = self.encoder(x, mask) #, cache)
|
|
||||||
# y, cache = self.decoder(x, mask, cache)
|
# y, cache = self.decoder(x, mask, cache)
|
||||||
# return self.lm_head(y), cache
|
# return self.lm_head(y), cache
|
||||||
return y #, cache
|
return y #, cache
|
||||||
|
Loading…
Reference in New Issue
Block a user