This commit is contained in:
Goekdeniz-Guelmez
2024-10-22 21:23:47 +02:00
parent 758597eaa8
commit 55485b98e8
3 changed files with 63 additions and 21 deletions

View File

@@ -193,6 +193,7 @@ class Mamba2(nn.Module):
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)