mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-01 03:28:08 +08:00 
			
		
		
		
	fix RoPE bug + minor updates
This commit is contained in:
		| @@ -41,6 +41,26 @@ class RMSNorm(nn.Module): | ||||
|         return self.weight * output | ||||
|  | ||||
|  | ||||
| class RoPE(nn.RoPE): | ||||
|     def __init__(self, dims: int, traditional: bool = False): | ||||
|         super().__init__(dims, traditional) | ||||
|  | ||||
|     def __call__(self, x, offset: int = 0): | ||||
|         shape = x.shape | ||||
|         x = mx.reshape(x, (-1, shape[-2], shape[-1])) | ||||
|         N = x.shape[1] + offset | ||||
|         costheta, sintheta = RoPE.create_cos_sin_theta( | ||||
|             N, self.dims, offset=offset, base=1000000, dtype=x.dtype | ||||
|         ) | ||||
|  | ||||
|         rope = ( | ||||
|             self._compute_traditional_rope if self.traditional else self._compute_rope | ||||
|         ) | ||||
|         rx = rope(costheta, sintheta, x) | ||||
|  | ||||
|         return mx.reshape(rx, shape) | ||||
|  | ||||
|  | ||||
| class Attention(nn.Module): | ||||
|     def __init__(self, args: ModelArgs): | ||||
|         super().__init__() | ||||
| @@ -57,7 +77,7 @@ class Attention(nn.Module): | ||||
|         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) | ||||
|         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) | ||||
|         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) | ||||
|         self.rope = nn.RoPE(args.head_dim, traditional=True) | ||||
|         self.rope = RoPE(args.head_dim, traditional=True) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
| @@ -126,7 +146,10 @@ class MOEFeedForward(nn.Module): | ||||
|  | ||||
|         gates = self.gate(x) | ||||
|         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] | ||||
|         scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1) | ||||
|         scores = mx.softmax( | ||||
|             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), | ||||
|             axis=-1, | ||||
|         ).astype(gates.dtype) | ||||
|  | ||||
|         y = [] | ||||
|         for xt, st, it in zip(x, scores, inds.tolist()): | ||||
| @@ -182,8 +205,9 @@ class Mixtral(nn.Module): | ||||
|         h = self.tok_embeddings(inputs) | ||||
|  | ||||
|         mask = None | ||||
|         if h.shape[1] > 1: | ||||
|             mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) | ||||
|         T = h.shape[1] | ||||
|         if T > 1: | ||||
|             mask = nn.MultiHeadAttention.create_additive_causal_mask(T) | ||||
|             mask = mask.astype(h.dtype) | ||||
|  | ||||
|         if cache is None: | ||||
| @@ -192,7 +216,7 @@ class Mixtral(nn.Module): | ||||
|         for e, layer in enumerate(self.layers): | ||||
|             h, cache[e] = layer(h, mask, cache[e]) | ||||
|  | ||||
|         return self.output(self.norm(h)), cache | ||||
|         return self.output(self.norm(h[:, T - 1 : T, :])), cache | ||||
|  | ||||
|  | ||||
| class Tokenizer: | ||||
| @@ -278,7 +302,7 @@ if __name__ == "__main__": | ||||
|         "--temp", | ||||
|         help="The sampling temperature.", | ||||
|         type=float, | ||||
|         default=1.0, | ||||
|         default=0.0, | ||||
|     ) | ||||
|     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun