Merge pull request #90 from bofenghuang/fix-fp16

Fix whisper fp16 inference
This commit is contained in:
Awni Hannun
2023-12-13 07:29:10 -08:00
committed by GitHub

View File

@@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module):
if self.cross_attn: if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
x += y x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
return x, (kv, cross_kv) return x, (kv, cross_kv)
@@ -134,8 +134,8 @@ class AudioEncoder(nn.Module):
self.ln_post = LayerNorm(n_state) self.ln_post = LayerNorm(n_state)
def __call__(self, x): def __call__(self, x):
x = nn.gelu(self.conv1(x)) x = nn.gelu(self.conv1(x)).astype(x.dtype)
x = nn.gelu(self.conv2(x)) x = nn.gelu(self.conv2(x)).astype(x.dtype)
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
x = x + self._positional_embedding x = x + self._positional_embedding