mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	fix (#1566)
This commit is contained in:
		| @@ -75,7 +75,7 @@ void CustomKernel::eval_gpu( | |||||||
|   MTL::Size group_dims = MTL::Size(tx, ty, tz); |   MTL::Size group_dims = MTL::Size(tx, ty, tz); | ||||||
|   const auto [gx, gy, gz] = grid_; |   const auto [gx, gy, gz] = grid_; | ||||||
|   MTL::Size grid_dims = MTL::Size(gx, gy, gz); |   MTL::Size grid_dims = MTL::Size(gx, gy, gz); | ||||||
|   compute_encoder->dispatchThreads(grid_dims, group_dims); |   compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||||
|  |  | ||||||
|   d.add_temporaries(std::move(copies), s.index); |   d.add_temporaries(std::move(copies), s.index); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -738,7 +738,7 @@ void fft_op( | |||||||
|     auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); |     auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); | ||||||
|     auto grid_dims = |     auto grid_dims = | ||||||
|         MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); |         MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); | ||||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); |     compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   d.add_temporaries(std::move(copies), s.index); |   d.add_temporaries(std::move(copies), s.index); | ||||||
|   | |||||||
| @@ -144,7 +144,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|  |  | ||||||
|     MTL::Size group_dims = MTL::Size(1, threads_per, 1); |     MTL::Size group_dims = MTL::Size(1, threads_per, 1); | ||||||
|     MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); |     MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); | ||||||
|     compute_encoder->dispatchThreads(grid_dims, group_dims); |     compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   if (m > 1) { |   if (m > 1) { | ||||||
|   | |||||||
| @@ -139,7 +139,7 @@ void sdpa_full_self_attention_metal( | |||||||
|   MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); |   MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); | ||||||
|   MTL::Size group_dims = MTL::Size(32, wm, wn); |   MTL::Size group_dims = MTL::Size(32, wm, wn); | ||||||
|  |  | ||||||
|   compute_encoder->dispatchThreadgroups(grid_dims, group_dims); |   compute_encoder.dispatchThreadgroups(grid_dims, group_dims); | ||||||
| } | } | ||||||
|  |  | ||||||
| void sdpa_vector( | void sdpa_vector( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun