mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	| @@ -234,7 +234,7 @@ void scan_dispatch( | |||||||
|       auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; |       auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; | ||||||
|       auto init = (issubdtype(input.dtype(), floating)) |       auto init = (issubdtype(input.dtype(), floating)) | ||||||
|           ? static_cast<U>(-std::numeric_limits<float>::infinity()) |           ? static_cast<U>(-std::numeric_limits<float>::infinity()) | ||||||
|           : std::numeric_limits<U>::max(); |           : std::numeric_limits<U>::min(); | ||||||
|       auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init); |       auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init); | ||||||
|       auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init); |       auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init); | ||||||
|       scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive); |       scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive); | ||||||
|   | |||||||
| @@ -309,6 +309,7 @@ template < | |||||||
|         } |         } | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|     // Share the prefix |     // Share the prefix | ||||||
|     if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { |     if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { | ||||||
|   | |||||||
| @@ -65,16 +65,16 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|  |  | ||||||
|     // Compute the thread grid |     // Compute the thread grid | ||||||
|     int n_reads = (in.itemsize() <= 4) ? 4 : 2; |     int n_reads = (in.itemsize() <= 4) ? 4 : 2; | ||||||
|     int elements_per_simd = n_reads * 32; |     constexpr int simd_size = 32; | ||||||
|  |     int elements_per_simd = n_reads * simd_size; | ||||||
|     int thread_groups = in.size() / size; |     int thread_groups = in.size() / size; | ||||||
|     int thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); |     int thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); | ||||||
|     if (size < n_reads * 1024) { |     if (size <= n_reads * 1024) { | ||||||
|       thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) * |  | ||||||
|           elements_per_simd; |  | ||||||
|     } else if (size < n_reads * 2048) { |  | ||||||
|       thread_group_size = |       thread_group_size = | ||||||
|           ((size / 2 + elements_per_simd - 1) / elements_per_simd) * |           ((size + elements_per_simd - 1) / elements_per_simd) * simd_size; | ||||||
|           elements_per_simd; |     } else if (size <= n_reads * 2048) { | ||||||
|  |       thread_group_size = | ||||||
|  |           ((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size; | ||||||
|     } |     } | ||||||
|     thread_group_size = std::min( |     thread_group_size = std::min( | ||||||
|         thread_group_size, |         thread_group_size, | ||||||
|   | |||||||
| @@ -1678,7 +1678,9 @@ class TestOps(mlx_tests.MLXTestCase): | |||||||
|                 c_mlx = mxop(a_mlx, axis=axis) |                 c_mlx = mxop(a_mlx, axis=axis) | ||||||
|                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) |                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3)) | ||||||
|  |  | ||||||
|  |         a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100) | ||||||
|         for op in ["cumsum", "cumprod", "cummax", "cummin"]: |         for op in ["cumsum", "cumprod", "cummax", "cummin"]: | ||||||
|  |             mxop = getattr(mx, op) | ||||||
|             c1 = mxop(a_mlx, axis=2) |             c1 = mxop(a_mlx, axis=2) | ||||||
|             c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False) |             c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False) | ||||||
|             self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:])) |             self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:])) | ||||||
| @@ -1719,6 +1721,18 @@ class TestOps(mlx_tests.MLXTestCase): | |||||||
|             out = mx.cumsum(a_t, axis=-1) |             out = mx.cumsum(a_t, axis=-1) | ||||||
|             expected = (mat_t * a_t[:, None, :]).sum(axis=-1) |             expected = (mat_t * a_t[:, None, :]).sum(axis=-1) | ||||||
|             self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3)) |             self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3)) | ||||||
|  |         sizes = [1023, 1024, 1025, 2047, 2048, 2049] | ||||||
|  |         for s in sizes: | ||||||
|  |             a = mx.ones((s,), mx.int32) | ||||||
|  |             out = mx.cumsum(a) | ||||||
|  |             expected = mx.arange(1, s + 1, dtype=mx.int32) | ||||||
|  |             self.assertTrue(mx.array_equal(expected, out)) | ||||||
|  |  | ||||||
|  |             # non-contiguous scan | ||||||
|  |             a = mx.ones((s, 2), mx.int32) | ||||||
|  |             out = mx.cumsum(a, axis=0) | ||||||
|  |             expected = mx.repeat(expected[:, None], 2, axis=1) | ||||||
|  |             self.assertTrue(mx.array_equal(expected, out)) | ||||||
|  |  | ||||||
|     def test_squeeze_expand(self): |     def test_squeeze_expand(self): | ||||||
|         a = mx.zeros((2, 1, 2, 1)) |         a = mx.zeros((2, 1, 2, 1)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun