mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
comments
This commit is contained in:
parent
e4333a3325
commit
bcede4bc0c
@ -92,14 +92,6 @@ class Attention(nn.Module):
|
|||||||
scores += mask
|
scores += mask
|
||||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
# queries = queries.reshape(B, self.n_kv_heads, self.repeats, L, -1)
|
|
||||||
# scores = (queries * self.scale) @ mx.expand_dims(keys.transpose(0, 1, 3, 2), 2)
|
|
||||||
# if mask is not None:
|
|
||||||
# scores += mask
|
|
||||||
# scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
|
||||||
# output = (scores @ mx.expand_dims(values, 2)).reshape(B, self.n_heads, L, -1)
|
|
||||||
# output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
return self.wo(output), (keys, values)
|
return self.wo(output), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
@ -146,13 +138,9 @@ class Mistral(nn.Module):
|
|||||||
self.vocab_size = args.vocab_size
|
self.vocab_size = args.vocab_size
|
||||||
self.n_layers = args.n_layers
|
self.n_layers = args.n_layers
|
||||||
assert self.vocab_size > 0
|
assert self.vocab_size > 0
|
||||||
|
|
||||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||||
|
|
||||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||||
|
|
||||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
|
||||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
Loading…
Reference in New Issue
Block a user