mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Check edge case handling in row reduce med kernel (#858)
This commit is contained in:
parent
cec8661113
commit
b219d12a6b
@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
|
|||||||
const short i_ed = short(reduction_size);
|
const short i_ed = short(reduction_size);
|
||||||
const short i_jump = reductions_per_thread;
|
const short i_jump = reductions_per_thread;
|
||||||
|
|
||||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
if(r_st < r_jump) {
|
||||||
|
for(short r = r_st; r < r_ed; r += r_jump) {
|
||||||
|
|
||||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||||
const device T * in_row = in + in_idx;
|
const device T * in_row = in + in_idx;
|
||||||
|
|
||||||
|
for(short i = i_st; i < i_ed; i += i_jump) {
|
||||||
|
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||||
|
}
|
||||||
|
|
||||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
|
||||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -590,6 +590,18 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))
|
self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))
|
||||||
self.assertTrue(np.all(sum_npy == sum_mlx))
|
self.assertTrue(np.all(sum_npy == sum_mlx))
|
||||||
|
|
||||||
|
x_npy = (
|
||||||
|
np.arange(3 * 2 * 3 * 3 * 3 * 3)
|
||||||
|
.reshape(3, 2, 3, 3, 3, 3)
|
||||||
|
.astype(np.float32)
|
||||||
|
)
|
||||||
|
x_mlx = mx.array(x_npy)
|
||||||
|
|
||||||
|
y_mlx = x_mlx.sum(axis=(0, 1, 3, 4, 5))
|
||||||
|
y_npy = x_npy.sum(axis=(0, 1, 3, 4, 5))
|
||||||
|
|
||||||
|
self.assertTrue(np.array_equal(y_mlx, y_npy))
|
||||||
|
|
||||||
def test_prod(self):
|
def test_prod(self):
|
||||||
x = mx.array(
|
x = mx.array(
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user