Merge pull request #90 from bofenghuang/fix-fp16

Fix whisper fp16 inference
Awni Hannun 2023-12-13 07:29:10 -08:00 committed by GitHub
commit 700b67fa3a


@@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module):
         if self.cross_attn:
             y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
             x += y
-        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
+        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
         return x, (kv, cross_kv)
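
The change casts the GELU activation back to the input's dtype before the second MLP projection. A minimal sketch of the failure mode this guards against, assuming MLX's nn.gelu can promote float16 activations to float32 (the promotion this PR works around); the array shape here is illustrative only:

    # Sketch: without the cast, a float16 graph can silently become float32.
    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((2, 4)).astype(mx.float16)
    y = nn.gelu(x)            # may come back promoted to float32
    y = y.astype(x.dtype)     # the fix: cast back to the input dtype
    print(x.dtype, y.dtype)   # both float16 after the cast

Once one activation is promoted, every downstream matmul runs in fp32, defeating the point of fp16 inference.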
@@ -134,8 +134,8 @@ class AudioEncoder(nn.Module):
         self.ln_post = LayerNorm(n_state)

     def __call__(self, x):
-        x = nn.gelu(self.conv1(x))
-        x = nn.gelu(self.conv2(x))
+        x = nn.gelu(self.conv1(x)).astype(x.dtype)
+        x = nn.gelu(self.conv2(x)).astype(x.dtype)
         assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
         x = x + self._positional_embedding
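
The encoder gets the same treatment after each convolution. If the pattern were needed in more places, it could be factored into a small helper; a hypothetical sketch (gelu_keep_dtype is not part of this PR):

    import mlx.core as mx
    import mlx.nn as nn

    def gelu_keep_dtype(x: mx.array) -> mx.array:
        # Apply GELU, then cast back to the input dtype so a float16
        # graph is not promoted to float32 mid-network.
        return nn.gelu(x).astype(x.dtype)

    # e.g. in AudioEncoder.__call__:
    #   x = gelu_keep_dtype(self.conv1(x))
    #   x = gelu_keep_dtype(self.conv2(x))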