Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-31 07:58:14 +08:00
@@ -234,7 +234,7 @@ void scan_dispatch(
       auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
       auto init = (issubdtype(input.dtype(), floating))
           ? static_cast<U>(-std::numeric_limits<float>::infinity())
-          : std::numeric_limits<U>::max();
+          : std::numeric_limits<U>::min();
       auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
       auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
       scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
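The lambda in this hunk computes a running maximum, so its identity element must be the lowest value the accumulator type can hold: -infinity for floats, and numeric_limits<U>::min() for integers. Seeding the scan with max() instead made every integer cummax prefix saturate at the type's maximum. A minimal sketch of the corrected behavior (illustration only, not part of the patch):

    import mlx.core as mx

    # With the max-scan identity at the type's lowest value, cummax over
    # all-negative int32 data tracks the true running maximum instead of
    # returning INT32_MAX everywhere.
    a = mx.array([-5, -2, -7, -1], dtype=mx.int32)
    print(mx.cummax(a))  # array([-5, -2, -2, -1], dtype=int32)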
@@ -309,6 +309,7 @@ template <
         }
       }
     }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Share the prefix
     if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
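The added threadgroup_barrier makes every simd group's writes to threadgroup memory visible before the last lane of the last simd group publishes the running prefix below it; without the barrier, that prefix exchange could read stale partial results. A hedged repro sketch in the spirit of the tests added further down (the sizes are my assumption, chosen to span several simd groups):

    import mlx.core as mx

    # Sizes beyond one simd group's coverage (n_reads * 32 elements)
    # force the cross-simd-group prefix exchange the barrier protects.
    for s in [257, 1024, 4096]:
        a = mx.ones((s,), mx.int32)
        out = mx.cumsum(a)
        expected = mx.arange(1, s + 1, dtype=mx.int32)
        assert mx.array_equal(out, expected)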
@@ -65,16 +65,16 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
 
     // Compute the thread grid
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
-    int elements_per_simd = n_reads * 32;
+    constexpr int simd_size = 32;
+    int elements_per_simd = n_reads * simd_size;
     int thread_groups = in.size() / size;
     int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (size < n_reads * 1024) {
-      thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
-          elements_per_simd;
-    } else if (size < n_reads * 2048) {
+    if (size <= n_reads * 1024) {
+      thread_group_size =
+          ((size + elements_per_simd - 1) / elements_per_simd) * simd_size;
+    } else if (size <= n_reads * 2048) {
       thread_group_size =
-          ((size / 2 + elements_per_simd - 1) / elements_per_simd) *
-          elements_per_simd;
+          ((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size;
     }
     thread_group_size = std::min(
         thread_group_size,
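Two fixes land in this hunk: the branch boundaries become inclusive (<=), and the rounded-up simd-group count is multiplied by simd_size (a thread count) rather than elements_per_simd (an element count). Since each thread consumes n_reads elements, the old formula requested roughly n_reads times too many threads. A worked sketch of the arithmetic, with values I am assuming for illustration:

    # Assumed values: size = 2048, n_reads = 4, so the first branch applies
    # under both the old (<) and new (<=) comparisons.
    size, n_reads, simd_size = 2048, 4, 32
    elements_per_simd = n_reads * simd_size                            # 128
    simd_groups = (size + elements_per_simd - 1) // elements_per_simd  # 16

    old_request = simd_groups * elements_per_simd  # 2048 threads
    new_request = simd_groups * simd_size          # 512 threads

    # 512 threads each reading n_reads = 4 elements cover all 2048
    # elements exactly; the old request overshot by a factor of n_reads.
    print(old_request, new_request)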
@@ -1678,7 +1678,9 @@ class TestOps(mlx_tests.MLXTestCase):
                 c_mlx = mxop(a_mlx, axis=axis)
                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
 
+        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
+        for op in ["cumsum", "cumprod", "cummax", "cummin"]:
             mxop = getattr(mx, op)
             c1 = mxop(a_mlx, axis=2)
             c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False)
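The equality the test checks follows from the definition of an exclusive scan: it is the inclusive scan shifted right by one position, with the operation's identity in the first slot. A minimal illustration (not part of the patch):

    import mlx.core as mx

    a = mx.array([1, 2, 3, 4])
    inc = mx.cumsum(a)                                  # [1, 3, 6, 10]
    exc = mx.cumsum(a, inclusive=False, reverse=False)  # [0, 1, 3, 6]
    assert mx.array_equal(inc[:-1], exc[1:])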
@@ -1719,6 +1721,18 @@ class TestOps(mlx_tests.MLXTestCase):
             out = mx.cumsum(a_t, axis=-1)
             expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
             self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
+        sizes = [1023, 1024, 1025, 2047, 2048, 2049]
+        for s in sizes:
+            a = mx.ones((s,), mx.int32)
+            out = mx.cumsum(a)
+            expected = mx.arange(1, s + 1, dtype=mx.int32)
+            self.assertTrue(mx.array_equal(expected, out))
+
+            # non-contiguous scan
+            a = mx.ones((s, 2), mx.int32)
+            out = mx.cumsum(a, axis=0)
+            expected = mx.repeat(expected[:, None], 2, axis=1)
+            self.assertTrue(mx.array_equal(expected, out))
 
     def test_squeeze_expand(self):
         a = mx.zeros((2, 1, 2, 1))
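The new test probes sizes just below, at, and above the branch boundaries fixed in the eval_gpu hunk, and its second half scans axis 0 of an (s, 2) array, where consecutive elements along the scan axis sit two items apart in memory, exercising the strided rather than the contiguous kernel. A small illustration of what that non-contiguous scan computes:

    import mlx.core as mx

    # cumsum along axis 0 of a (5, 2) array: column j is the running
    # sum of a[:, j], read with a stride of 2 elements in memory.
    a = mx.arange(10, dtype=mx.int32).reshape(5, 2)
    print(mx.cumsum(a, axis=0))
    # [[ 0,  1], [ 2,  4], [ 6,  9], [12, 16], [20, 25]]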
Author: Awni Hannun