Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 07:58:14 +08:00)
			
		
		
		
	split logsumexp
This commit is contained in:
		| @@ -5,28 +5,33 @@ template <typename T, typename AccT = float, int N_READS = 4> | ||||
|     const device T* in, | ||||
|     device T* out, | ||||
|     constant int& axis_size, | ||||
|     uint gid [[threadgroup_position_in_grid]], | ||||
|     uint _lid [[thread_position_in_threadgroup]], | ||||
|     uint2 gid [[threadgroup_position_in_grid]], | ||||
|     uint2 tid [[thread_position_in_grid]], | ||||
|     uint2 grid_dim [[threads_per_grid]], | ||||
|     uint2 _lid [[thread_position_in_threadgroup]], | ||||
|     uint simd_lane_id [[thread_index_in_simdgroup]], | ||||
|     uint simd_group_id [[simdgroup_index_in_threadgroup]]) { | ||||
|   int lid = _lid; | ||||
|   int lid = _lid.x; | ||||
|  | ||||
|   constexpr int SIMD_SIZE = 32; | ||||
|   constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS; | ||||
|  | ||||
|   threadgroup AccT local_max[SIMD_SIZE]; | ||||
|   threadgroup AccT local_normalizer[SIMD_SIZE]; | ||||
|  | ||||
|   AccT ld[N_READS]; | ||||
|  | ||||
|   in += gid * size_t(axis_size) + lid * N_READS; | ||||
|   if (lid * N_READS + N_READS <= axis_size) { | ||||
|   const int axis_offset = tid.y * elem_per_group; | ||||
|   in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset; | ||||
|   if (axis_offset + lid * N_READS + N_READS <= axis_size) { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       ld[i] = AccT(in[i]); | ||||
|     } | ||||
|   } else { | ||||
|     for (int i = 0; i < N_READS; i++) { | ||||
|       ld[i] = | ||||
|           ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min; | ||||
|       ld[i] = ((axis_offset + lid * N_READS + i) < axis_size) | ||||
|           ? AccT(in[i]) | ||||
|           : Limits<AccT>::min; | ||||
|     } | ||||
|   } | ||||
|   if (simd_group_id == 0) { | ||||
| @@ -55,6 +60,7 @@ template <typename T, typename AccT = float, int N_READS = 4> | ||||
|   maxval = local_max[0]; | ||||
|  | ||||
|   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer | ||||
|   out += gid.x * grid_dim.y + tid.y; | ||||
|   AccT normalizer = 0; | ||||
|   for (int i = 0; i < N_READS; i++) { | ||||
|     normalizer += fast::exp(ld[i] - maxval); | ||||
| @@ -67,7 +73,7 @@ template <typename T, typename AccT = float, int N_READS = 4> | ||||
|   if (simd_group_id == 0) { | ||||
|     normalizer = simd_sum(local_normalizer[simd_lane_id]); | ||||
|     if (simd_lane_id == 0) { | ||||
|       out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); | ||||
|       out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|   | ||||
| @@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   const int n_reads = 4; | ||||
|   const int looped_limit = LOGSUMEXP_LOOPED_LIMIT; | ||||
|  | ||||
|   std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_"; | ||||
|   bool split = n_rows < 4 && axis_size > 4 * looped_limit; | ||||
|   bool looped = !split && axis_size > looped_limit; | ||||
|   std::string kernel_name = looped ? "looped_" : "block_"; | ||||
|   kernel_name += "logsumexp_"; | ||||
|   kernel_name += type_to_name(out); | ||||
|  | ||||
|   auto kernel = get_logsumexp_kernel(d, kernel_name, out); | ||||
|   auto& compute_encoder = d.get_command_encoder(s.index); | ||||
|   if (split) { | ||||
|     auto tmp_size = ceildiv(axis_size, looped_limit); | ||||
|     auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)}; | ||||
|     array tmp(tmp_shape, in.dtype(), nullptr, {}); | ||||
|     tmp.set_data(allocator::malloc(tmp.nbytes())); | ||||
|     size_t threadgroup_size = 1024; | ||||
|     assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); | ||||
|     size_t n_threads = n_rows * threadgroup_size; | ||||
|     auto grid_dims = MTL::Size(n_threads, tmp_size, 1); | ||||
|     auto group_dims = MTL::Size(threadgroup_size, 1, 1); | ||||
|     compute_encoder.set_compute_pipeline_state(kernel); | ||||
|     compute_encoder.set_input_array(in, 0); | ||||
|     compute_encoder.set_output_array(tmp, 1); | ||||
|     compute_encoder.set_bytes(axis_size, 2); | ||||
|     compute_encoder.dispatch_threads(grid_dims, group_dims); | ||||
|     d.add_temporary(tmp, s.index); | ||||
|     in = tmp; | ||||
|     axis_size = tmp_size; | ||||
|   } | ||||
|  | ||||
|   { | ||||
|     MTL::Size grid_dims, group_dims; | ||||
|     if (axis_size <= looped_limit) { | ||||
|     if (!looped) { | ||||
|       size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; | ||||
|       size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; | ||||
|       size_t threadgroup_size = simd_size * simds_needed; | ||||
|   | ||||
| @@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8)) | ||||
|         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) | ||||
|  | ||||
|         # Even larger | ||||
|         x = mx.random.uniform(shape=(4 * 4096 + 3,)) | ||||
|         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x))) | ||||
|  | ||||
|     def test_mean(self): | ||||
|         x = mx.array( | ||||
|             [ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun