mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Merge pull request #90 from bofenghuang/fix-fp16
Fix whisper fp16 inference
This commit is contained in:
commit
700b67fa3a
@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module):
|
|||||||
if self.cross_attn:
|
if self.cross_attn:
|
||||||
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
|
||||||
x += y
|
x += y
|
||||||
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
|
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
|
||||||
return x, (kv, cross_kv)
|
return x, (kv, cross_kv)
|
||||||
|
|
||||||
|
|
||||||
@ -134,8 +134,8 @@ class AudioEncoder(nn.Module):
|
|||||||
self.ln_post = LayerNorm(n_state)
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = nn.gelu(self.conv1(x))
|
x = nn.gelu(self.conv1(x)).astype(x.dtype)
|
||||||
x = nn.gelu(self.conv2(x))
|
x = nn.gelu(self.conv2(x)).astype(x.dtype)
|
||||||
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
|
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
|
||||||
x = x + self._positional_embedding
|
x = x + self._positional_embedding
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user