mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
use fast group norm
This commit is contained in:
@@ -30,7 +30,7 @@ class EncodecConv1d(nn.Module):
|
||||
in_channels, out_channels, kernel_size, stride, dilation=dilation
|
||||
)
|
||||
if self.norm_type == "time_group_norm":
|
||||
self.norm = nn.GroupNorm(1, out_channels)
|
||||
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
|
||||
|
||||
self.stride = stride
|
||||
|
||||
|
Reference in New Issue
Block a user