mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Check edge case handling in row reduce med kernel (#858)
This commit is contained in:
		@@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
 | 
			
		||||
    const short i_ed = short(reduction_size);
 | 
			
		||||
    const short i_jump = reductions_per_thread;
 | 
			
		||||
 | 
			
		||||
    for(short r = r_st; r < r_ed; r += r_jump) {
 | 
			
		||||
    if(r_st < r_jump) {
 | 
			
		||||
      for(short r = r_st; r < r_ed; r += r_jump) {
 | 
			
		||||
 | 
			
		||||
      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
 | 
			
		||||
      const device T * in_row = in + in_idx;
 | 
			
		||||
        uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
 | 
			
		||||
        const device T * in_row = in + in_idx;
 | 
			
		||||
 | 
			
		||||
        for(short i = i_st; i < i_ed; i += i_jump) {
 | 
			
		||||
          total_val = op(static_cast<U>(in_row[i]), total_val);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
      for(short i = i_st; i < i_ed; i += i_jump) {
 | 
			
		||||
        total_val = op(static_cast<U>(in_row[i]), total_val);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user