This commit is contained in:
Goekdeniz-Guelmez 2024-12-27 15:37:41 +01:00
parent 2ed51946ab
commit 3384d38a83

View File

@@ -98,10 +98,7 @@ class Mamba2Block(nn.Module):
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
-if args.intermediate_size == None:
-self.d_inner = int(self.expand * self.d_model)
-else:
-self.d_inner = args.intermediate_size
+self.d_inner = args.intermediate_size or int(self.expand * self.d_model)
self.n_groups = args.n_groups
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads