This commit is contained in:
bofenghuang 2023-12-13 11:07:47 +01:00
parent 74c4ed40d2
commit 4b1a06c0cb

View File

@ -117,7 +117,7 @@ class ResidualAttentionBlock(nn.Module):
if self.cross_attn: if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
x += y x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
return x, (kv, cross_kv) return x, (kv, cross_kv)
@ -134,8 +134,8 @@ class AudioEncoder(nn.Module):
self.ln_post = LayerNorm(n_state) self.ln_post = LayerNorm(n_state)
def __call__(self, x): def __call__(self, x):
x = nn.gelu(self.conv1(x)) x = nn.gelu(self.conv1(x)).astype(x.dtype)
x = nn.gelu(self.conv2(x)) x = nn.gelu(self.conv2(x)).astype(x.dtype)
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
x = x + self._positional_embedding x = x + self._positional_embedding