mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Check edge case handling in row reduce med kernel (#858)
This commit is contained in:
@@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
|
||||
const short i_ed = short(reduction_size);
|
||||
const short i_jump = reductions_per_thread;
|
||||
|
||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
||||
if(r_st < r_jump) {
|
||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
||||
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
|
||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
|
||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user