This commit is contained in:
Awni Hannun 2023-12-19 06:39:17 -08:00
parent 992f5cc0fa
commit d4f7ecd851

View File

@ -5,7 +5,6 @@ from mlx.utils import tree_flatten
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"] __all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
STD = 0.02
class FeedForward(nn.Sequential): class FeedForward(nn.Sequential):
@ -25,9 +24,7 @@ class Attention(nn.Module):
self.heads = heads self.heads = heads
self.scale = dim**-0.5 self.scale = dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False) self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.out = nn.Sequential( self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))
nn.Linear(dim, dim), nn.Dropout(dropout)
)
def __call__(self, x): def __call__(self, x):
b, n, _, h = *x.shape, self.heads b, n, _, h = *x.shape, self.heads
@ -133,9 +130,11 @@ class KWT(nn.Module):
in_channels, dim, kernel_size=patch_res, stride=patch_res in_channels, dim, kernel_size=patch_res, stride=patch_res
) )
self.pos_embedding = mx.random.truncated_normal( self.pos_embedding = mx.random.truncated_normal(
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim) -0.01,
0.01,
(self.num_patches + 1, dim),
) )
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim)) self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
self.dropout = nn.Dropout(emb_dropout) self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout) self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.pool = pool self.pool = pool
@ -152,8 +151,6 @@ class KWT(nn.Module):
x = x.reshape(x.shape[0], -1, self.dim) x = x.reshape(x.shape[0], -1, self.dim)
assert x.shape[1] == self.num_patches assert x.shape[1] == self.num_patches
# x = x + self.pos_embedding[:, 1:, :]
cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim)) cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
x = mx.concatenate((cls_tokens, x), axis=1) x = mx.concatenate((cls_tokens, x), axis=1)