diff --git a/mistral/mistral.py b/mistral/mistral.py
index e846ebfd..6a6447bc 100644
--- a/mistral/mistral.py
+++ b/mistral/mistral.py
@@ -92,14 +92,6 @@ class Attention(nn.Module):
             scores += mask
         scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
         output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        # queries = queries.reshape(B, self.n_kv_heads, self.repeats, L, -1)
-        # scores = (queries * self.scale) @ mx.expand_dims(keys.transpose(0, 1, 3, 2), 2)
-        # if mask is not None:
-        #     scores += mask
-        # scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        # output = (scores @ mx.expand_dims(values, 2)).reshape(B, self.n_heads, L, -1)
-        # output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.wo(output), (keys, values)
@@ -146,13 +138,9 @@ class Mistral(nn.Module):
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
         assert self.vocab_size > 0
-        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
-        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

     def __call__(