Fix T5.__call__

This commit is contained in:
Juarez Bochi 2023-12-18 08:00:01 -05:00
parent 34843ddeb2
commit 689eda9937
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -314,7 +314,7 @@ class T5(nn.Module):
inputs: mx.array,
decoder_inputs: mx.array,
):
return decode(decoder_inputs, encode(inputs))[0]
return self.decode(decoder_inputs, self.encode(inputs))[0]
def generate(