diff --git a/speechcommands/kwt.py b/speechcommands/kwt.py index 40b4c71d..f68fd632 100644 --- a/speechcommands/kwt.py +++ b/speechcommands/kwt.py @@ -5,7 +5,6 @@ from mlx.utils import tree_flatten __all__ = ["KWT", "kwt1", "kwt2", "kwt3"] -STD = 0.02 class FeedForward(nn.Sequential): @@ -25,9 +24,7 @@ class Attention(nn.Module): self.heads = heads self.scale = dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=False) - self.out = nn.Sequential( - nn.Linear(dim, dim), nn.Dropout(dropout) - ) + self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout)) def __call__(self, x): b, n, _, h = *x.shape, self.heads @@ -133,9 +130,11 @@ class KWT(nn.Module): in_channels, dim, kernel_size=patch_res, stride=patch_res ) self.pos_embedding = mx.random.truncated_normal( - -1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim) + -0.01, + 0.01, + (self.num_patches + 1, dim), ) - self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim)) + self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,)) self.dropout = nn.Dropout(emb_dropout) self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout) self.pool = pool @@ -152,8 +151,6 @@ class KWT(nn.Module): x = x.reshape(x.shape[0], -1, self.dim) assert x.shape[1] == self.num_patches - # x = x + self.pos_embedding[:, 1:, :] - cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim)) x = mx.concatenate((cls_tokens, x), axis=1)