Awni Hannun 2023-12-05 21:36:47 -08:00
parent e4333a3325
commit bcede4bc0c

@@ -92,14 +92,6 @@ class Attention(nn.Module):
             scores += mask
         scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
         output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        # queries = queries.reshape(B, self.n_kv_heads, self.repeats, L, -1)
-        # scores = (queries * self.scale) @ mx.expand_dims(keys.transpose(0, 1, 3, 2), 2)
-        # if mask is not None:
-        #     scores += mask
-        # scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        # output = (scores @ mx.expand_dims(values, 2)).reshape(B, self.n_heads, L, -1)
-        # output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.wo(output), (keys, values)
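For context, the lines removed in this hunk are a commented-out broadcast variant of the attention computation kept above: instead of repeating key/value heads, the query heads are reshaped into groups so the matmul broadcasts over the kv heads. The following standalone sketch reconstructs that variant from the removed comments; the shapes and mlx ops mirror the diff, while the toy dimensions and random inputs are illustrative only and not part of the repository's code.

    import mlx.core as mx

    # Toy sizes chosen for illustration; the diff only names B, L, n_heads,
    # n_kv_heads, and the per-head dimension.
    B, L, n_heads, n_kv_heads, head_dim = 1, 4, 8, 2, 16
    repeats = n_heads // n_kv_heads
    scale = head_dim ** -0.5

    queries = mx.random.normal((B, n_heads, L, head_dim))
    keys = mx.random.normal((B, n_kv_heads, L, head_dim))
    values = mx.random.normal((B, n_kv_heads, L, head_dim))

    # Group query heads: (B, n_kv_heads, repeats, L, head_dim).
    q = queries.reshape(B, n_kv_heads, repeats, L, head_dim)
    # Broadcast keys/values across the `repeats` axis instead of tiling them.
    scores = (q * scale) @ mx.expand_dims(keys.transpose(0, 1, 3, 2), 2)
    scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
    output = (scores @ mx.expand_dims(values, 2)).reshape(B, n_heads, L, -1)
    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
    print(output.shape)  # (1, 4, 128) == (B, L, n_heads * head_dim)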
@@ -146,13 +138,9 @@ class Mistral(nn.Module):
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
         assert self.vocab_size > 0
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
         self.norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
     def __call__(
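This hunk only shows context from Mistral.__init__. As a quick orientation, here is a minimal sketch of how the listed submodules map token ids toward logits; the toy sizes are made up, mlx.nn.RMSNorm stands in for the file's own RMSNorm class, and the TransformerBlock stack is skipped for brevity, so this is not the repository's __call__ implementation.

    import mlx.core as mx
    import mlx.nn as nn

    # Illustrative dimensions; the real values come from the model's args.
    dim, vocab_size, norm_eps = 64, 100, 1e-5

    tok_embeddings = nn.Embedding(vocab_size, dim)
    norm = nn.RMSNorm(dim, eps=norm_eps)            # stand-in for the local RMSNorm
    output = nn.Linear(dim, vocab_size, bias=False)

    tokens = mx.array([[1, 2, 3]])
    h = tok_embeddings(tokens)   # (1, 3, dim)
    h = norm(h)                  # transformer blocks omitted in this sketch
    logits = output(h)           # (1, 3, vocab_size)
    print(logits.shape)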