Check edge case handling in row reduce med kernel (#858)

This commit is contained in:
Jagrit Digani 2024-03-20 11:37:58 -07:00 committed by GitHub
parent cec8661113
commit b219d12a6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 6 deletions

View File

@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
const short i_ed = short(reduction_size);
const short i_jump = reductions_per_thread;
for(short r = r_st; r < r_ed; r += r_jump) {
if(r_st < r_jump) {
for(short r = r_st; r < r_ed; r += r_jump) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
for(short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}

View File

@ -590,6 +590,18 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))
self.assertTrue(np.all(sum_npy == sum_mlx))
x_npy = (
np.arange(3 * 2 * 3 * 3 * 3 * 3)
.reshape(3, 2, 3, 3, 3, 3)
.astype(np.float32)
)
x_mlx = mx.array(x_npy)
y_mlx = x_mlx.sum(axis=(0, 1, 3, 4, 5))
y_npy = x_npy.sum(axis=(0, 1, 3, 4, 5))
self.assertTrue(np.array_equal(y_mlx, y_npy))
def test_prod(self):
x = mx.array(
[