diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index 62e43de3..bca69946 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module): if self.cross_attn: y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) x += y - x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) + x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype)) return x, (kv, cross_kv) @@ -134,8 +134,8 @@ class AudioEncoder(nn.Module): self.ln_post = LayerNorm(n_state) def __call__(self, x): - x = nn.gelu(self.conv1(x)) - x = nn.gelu(self.conv2(x)) + x = nn.gelu(self.conv1(x)).astype(x.dtype) + x = nn.gelu(self.conv2(x)).astype(x.dtype) assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" x = x + self._positional_embedding