Check edge case handling in row reduce med kernel (#858)

This commit is contained in:
Jagrit Digani
2024-03-20 11:37:58 -07:00
committed by GitHub
parent cec8661113
commit b219d12a6b
2 changed files with 20 additions and 6 deletions

View File

@@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
   const short i_ed = short(reduction_size);
   const short i_jump = reductions_per_thread;
-  for(short r = r_st; r < r_ed; r += r_jump) {
+  if(r_st < r_jump) {
+    for(short r = r_st; r < r_ed; r += r_jump) {
-    uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
-    const device T * in_row = in + in_idx;
+      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
+      const device T * in_row = in + in_idx;
-    for(short i = i_st; i < i_ed; i += i_jump) {
-      total_val = op(static_cast<U>(in_row[i]), total_val);
-    }
+      for(short i = i_st; i < i_ed; i += i_jump) {
+        total_val = op(static_cast<U>(in_row[i]), total_val);
+      }
+    }
   }