mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Fix fp16
This commit is contained in:
parent
74c4ed40d2
commit
4b1a06c0cb
@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module):
|
||||
if self.cross_attn:
|
||||
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
||||
x += y
|
||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
|
||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
|
||||
return x, (kv, cross_kv)
|
||||
|
||||
|
||||
@ -134,8 +134,8 @@ class AudioEncoder(nn.Module):
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x):
|
||||
x = nn.gelu(self.conv1(x))
|
||||
x = nn.gelu(self.conv2(x))
|
||||
x = nn.gelu(self.conv1(x)).astype(x.dtype)
|
||||
x = nn.gelu(self.conv2(x)).astype(x.dtype)
|
||||
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
|
||||
x = x + self._positional_embedding
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user