mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
nits
This commit is contained in:
parent
992f5cc0fa
commit
d4f7ecd851
@ -5,7 +5,6 @@ from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
|
||||
STD = 0.02
|
||||
|
||||
|
||||
class FeedForward(nn.Sequential):
|
||||
@ -25,9 +24,7 @@ class Attention(nn.Module):
|
||||
self.heads = heads
|
||||
self.scale = dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(dim, dim), nn.Dropout(dropout)
|
||||
)
|
||||
self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
@ -133,9 +130,11 @@ class KWT(nn.Module):
|
||||
in_channels, dim, kernel_size=patch_res, stride=patch_res
|
||||
)
|
||||
self.pos_embedding = mx.random.truncated_normal(
|
||||
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
|
||||
-0.01,
|
||||
0.01,
|
||||
(self.num_patches + 1, dim),
|
||||
)
|
||||
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
|
||||
self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
||||
self.pool = pool
|
||||
@ -152,8 +151,6 @@ class KWT(nn.Module):
|
||||
x = x.reshape(x.shape[0], -1, self.dim)
|
||||
assert x.shape[1] == self.num_patches
|
||||
|
||||
# x = x + self.pos_embedding[:, 1:, :]
|
||||
|
||||
cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
|
||||
x = mx.concatenate((cls_tokens, x), axis=1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user