use fast group norm

This commit is contained in:
Awni Hannun
2024-09-15 08:47:10 -07:00
parent c3209fd29a
commit 0a73862430

View File

@@ -30,7 +30,7 @@ class EncodecConv1d(nn.Module):
in_channels, out_channels, kernel_size, stride, dilation=dilation
)
if self.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels)
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.stride = stride