mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	[CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)
* Do vectorized store/load in unary ops * Do vectorized store/load in binary_two ops * Do vectorized store/load in copy ops * Do vectorized store/load in ternary ops * Use int32_t for IdxT * binary => binary_two in binary_two.cu * Fix tests on large arrays * Use uint as index type * Contig uses uint as index and non-contig uses int
This commit is contained in:
		@@ -20,15 +20,10 @@ namespace cg = cooperative_groups;
 | 
				
			|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
 | 
					__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  int remaining = size - index * N_READS;
 | 
					 | 
				
			||||||
  if (remaining <= 0) {
 | 
					 | 
				
			||||||
    return;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (remaining < N_READS) {
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    for (int i = 0; i < remaining; ++i) {
 | 
					    for (int i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
      IdxT offset = index * N_READS + i;
 | 
					      out[i] = Op{}(a[0], b[0]);
 | 
				
			||||||
      out[offset] = Op{}(a[0], b[0]);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    AlignedVector<Out, N_READS> out_vec;
 | 
					    AlignedVector<Out, N_READS> out_vec;
 | 
				
			||||||
@@ -44,15 +39,10 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
 | 
					__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  int remaining = size - index * N_READS;
 | 
					 | 
				
			||||||
  if (remaining <= 0) {
 | 
					 | 
				
			||||||
    return;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (remaining < N_READS) {
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    for (int i = 0; i < remaining; ++i) {
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
      IdxT offset = index * N_READS + i;
 | 
					      out[i] = Op{}(a[0], b[i]);
 | 
				
			||||||
      out[offset] = Op{}(a[0], b[offset]);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    auto b_vec = load_vector<N_READS>(b, index);
 | 
					    auto b_vec = load_vector<N_READS>(b, index);
 | 
				
			||||||
@@ -70,15 +60,10 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
 | 
					__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  int remaining = size - index * N_READS;
 | 
					 | 
				
			||||||
  if (remaining <= 0) {
 | 
					 | 
				
			||||||
    return;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (remaining < N_READS) {
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    for (int i = 0; i < remaining; ++i) {
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
      IdxT offset = index * N_READS + i;
 | 
					      out[i] = Op{}(a[i], b[0]);
 | 
				
			||||||
      out[offset] = Op{}(a[offset], b[0]);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    auto a_vec = load_vector<N_READS>(a, index);
 | 
					    auto a_vec = load_vector<N_READS>(a, index);
 | 
				
			||||||
@@ -96,15 +81,10 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			|||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
 | 
					__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  int remaining = size - index * N_READS;
 | 
					 | 
				
			||||||
  if (remaining <= 0) {
 | 
					 | 
				
			||||||
    return;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (remaining < N_READS) {
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    for (int i = 0; i < remaining; ++i) {
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
      IdxT offset = index * N_READS + i;
 | 
					      out[i] = Op{}(a[i], b[i]);
 | 
				
			||||||
      out[offset] = Op{}(a[offset], b[offset]);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    auto a_vec = load_vector<N_READS>(a, index);
 | 
					    auto a_vec = load_vector<N_READS>(a, index);
 | 
				
			||||||
@@ -267,7 +247,7 @@ void binary_op_gpu_inplace(
 | 
				
			|||||||
                }
 | 
					                }
 | 
				
			||||||
              });
 | 
					              });
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
          dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
 | 
					          dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
 | 
				
			||||||
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
					            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
				
			||||||
            // TODO: Choose optimized value based on type size.
 | 
					            // TODO: Choose optimized value based on type size.
 | 
				
			||||||
            constexpr int N_READS = 4;
 | 
					            constexpr int N_READS = 4;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,52 +17,119 @@ namespace cu {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
namespace cg = cooperative_groups;
 | 
					namespace cg = cooperative_groups;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void
 | 
					__global__ void
 | 
				
			||||||
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
					binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    auto out = Op{}(a[0], b[0]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    out_a[0] = out[0];
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
    out_b[0] = out[1];
 | 
					      auto out = Op{}(a[0], b[0]);
 | 
				
			||||||
 | 
					      out_a[i] = out[0];
 | 
				
			||||||
 | 
					      out_b[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_a_vec;
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_b_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      auto out = Op{}(a[0], b[0]);
 | 
				
			||||||
 | 
					      out_a_vec.val[i] = out[0];
 | 
				
			||||||
 | 
					      out_b_vec.val[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_a, index, out_a_vec);
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_b, index, out_b_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void
 | 
					__global__ void
 | 
				
			||||||
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
					binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    auto out = Op{}(a[0], b[index]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    out_a[index] = out[0];
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
    out_b[index] = out[1];
 | 
					      auto out = Op{}(a[0], b[i]);
 | 
				
			||||||
 | 
					      out_a[i] = out[0];
 | 
				
			||||||
 | 
					      out_b[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto b_vec = load_vector<N_READS>(b, index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_a_vec;
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_b_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      auto out = Op{}(a[0], b_vec.val[i]);
 | 
				
			||||||
 | 
					      out_a_vec.val[i] = out[0];
 | 
				
			||||||
 | 
					      out_b_vec.val[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_a, index, out_a_vec);
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_b, index, out_b_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void
 | 
					__global__ void
 | 
				
			||||||
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
					binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    auto out = Op{}(a[index], b[0]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    out_a[index] = out[0];
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
    out_b[index] = out[1];
 | 
					      auto out = Op{}(a[i], b[0]);
 | 
				
			||||||
 | 
					      out_a[i] = out[0];
 | 
				
			||||||
 | 
					      out_b[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto a_vec = load_vector<N_READS>(a, index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_a_vec;
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_b_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      auto out = Op{}(a_vec.val[i], b[0]);
 | 
				
			||||||
 | 
					      out_a_vec.val[i] = out[0];
 | 
				
			||||||
 | 
					      out_b_vec.val[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_a, index, out_a_vec);
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_b, index, out_b_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void
 | 
					__global__ void
 | 
				
			||||||
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
					binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    auto out = Op{}(a[index], b[index]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
    out_a[index] = out[0];
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
    out_b[index] = out[1];
 | 
					      auto out = Op{}(a[i], b[i]);
 | 
				
			||||||
 | 
					      out_a[i] = out[0];
 | 
				
			||||||
 | 
					      out_b[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto a_vec = load_vector<N_READS>(a, index);
 | 
				
			||||||
 | 
					    auto b_vec = load_vector<N_READS>(b, index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_a_vec;
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_b_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      auto out = Op{}(a_vec.val[i], b_vec.val[i]);
 | 
				
			||||||
 | 
					      out_a_vec.val[i] = out[0];
 | 
				
			||||||
 | 
					      out_b_vec.val[i] = out[1];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_a, index, out_a_vec);
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out_b, index, out_b_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
 | 
				
			||||||
__global__ void binary_g_nd(
 | 
					__global__ void binary_two_g_nd(
 | 
				
			||||||
    const In* a,
 | 
					    const In* a,
 | 
				
			||||||
    const In* b,
 | 
					    const In* b,
 | 
				
			||||||
    Out* out_a,
 | 
					    Out* out_a,
 | 
				
			||||||
@@ -82,7 +149,7 @@ __global__ void binary_g_nd(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT>
 | 
					template <typename Op, typename In, typename Out, typename IdxT>
 | 
				
			||||||
__global__ void binary_g(
 | 
					__global__ void binary_two_g(
 | 
				
			||||||
    const In* a,
 | 
					    const In* a,
 | 
				
			||||||
    const In* b,
 | 
					    const In* b,
 | 
				
			||||||
    Out* out_a,
 | 
					    Out* out_a,
 | 
				
			||||||
@@ -103,7 +170,7 @@ __global__ void binary_g(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out>
 | 
					template <typename Op, typename In, typename Out>
 | 
				
			||||||
constexpr bool supports_binary_op() {
 | 
					constexpr bool supports_binary_two_op() {
 | 
				
			||||||
  if (std::is_same_v<Op, DivMod>) {
 | 
					  if (std::is_same_v<Op, DivMod>) {
 | 
				
			||||||
    return std::is_same_v<In, Out> &&
 | 
					    return std::is_same_v<In, Out> &&
 | 
				
			||||||
        (std::is_integral_v<Out> || is_floating_v<Out>);
 | 
					        (std::is_integral_v<Out> || is_floating_v<Out>);
 | 
				
			||||||
@@ -114,7 +181,7 @@ constexpr bool supports_binary_op() {
 | 
				
			|||||||
} // namespace cu
 | 
					} // namespace cu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op>
 | 
					template <typename Op>
 | 
				
			||||||
void binary_op_gpu_inplace(
 | 
					void binary_two_op_gpu_inplace(
 | 
				
			||||||
    const std::vector<array>& inputs,
 | 
					    const std::vector<array>& inputs,
 | 
				
			||||||
    std::vector<array>& outputs,
 | 
					    std::vector<array>& outputs,
 | 
				
			||||||
    std::string_view op,
 | 
					    std::string_view op,
 | 
				
			||||||
@@ -141,7 +208,7 @@ void binary_op_gpu_inplace(
 | 
				
			|||||||
    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
 | 
					    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
 | 
				
			||||||
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
 | 
					      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
 | 
				
			||||||
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
 | 
					      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
 | 
				
			||||||
      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
 | 
					      if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
 | 
				
			||||||
        using InType = cuda_type_t<CTYPE_IN>;
 | 
					        using InType = cuda_type_t<CTYPE_IN>;
 | 
				
			||||||
        using OutType = cuda_type_t<CTYPE_OUT>;
 | 
					        using OutType = cuda_type_t<CTYPE_OUT>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -161,8 +228,12 @@ void binary_op_gpu_inplace(
 | 
				
			|||||||
                int ndim = shape.size();
 | 
					                int ndim = shape.size();
 | 
				
			||||||
                if (ndim <= 3) {
 | 
					                if (ndim <= 3) {
 | 
				
			||||||
                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
 | 
					                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
 | 
				
			||||||
                    auto kernel = cu::
 | 
					                    auto kernel = cu::binary_two_g_nd<
 | 
				
			||||||
                        binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
 | 
					                        Op,
 | 
				
			||||||
 | 
					                        InType,
 | 
				
			||||||
 | 
					                        OutType,
 | 
				
			||||||
 | 
					                        IdxT,
 | 
				
			||||||
 | 
					                        dims_constant()>;
 | 
				
			||||||
                    auto [num_blocks, block_dims] =
 | 
					                    auto [num_blocks, block_dims] =
 | 
				
			||||||
                        get_launch_args(kernel, out_a, large());
 | 
					                        get_launch_args(kernel, out_a, large());
 | 
				
			||||||
                    encoder.add_kernel_node(
 | 
					                    encoder.add_kernel_node(
 | 
				
			||||||
@@ -179,7 +250,7 @@ void binary_op_gpu_inplace(
 | 
				
			|||||||
                        const_param<dims_constant()>(b_strides));
 | 
					                        const_param<dims_constant()>(b_strides));
 | 
				
			||||||
                  });
 | 
					                  });
 | 
				
			||||||
                } else {
 | 
					                } else {
 | 
				
			||||||
                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
 | 
					                  auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
 | 
				
			||||||
                  auto [num_blocks, block_dims] =
 | 
					                  auto [num_blocks, block_dims] =
 | 
				
			||||||
                      get_launch_args(kernel, out_a, large());
 | 
					                      get_launch_args(kernel, out_a, large());
 | 
				
			||||||
                  encoder.add_kernel_node(
 | 
					                  encoder.add_kernel_node(
 | 
				
			||||||
@@ -198,22 +269,25 @@ void binary_op_gpu_inplace(
 | 
				
			|||||||
                }
 | 
					                }
 | 
				
			||||||
              });
 | 
					              });
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
          dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
 | 
					          dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
 | 
				
			||||||
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
					            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
				
			||||||
            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
 | 
					            // TODO: Choose optimized value based on type size.
 | 
				
			||||||
 | 
					            constexpr int N_READS = 4;
 | 
				
			||||||
 | 
					            auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
            if (bopt == BinaryOpType::ScalarVector) {
 | 
					            if (bopt == BinaryOpType::ScalarVector) {
 | 
				
			||||||
              kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
 | 
					              kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
            } else if (bopt == BinaryOpType::VectorScalar) {
 | 
					            } else if (bopt == BinaryOpType::VectorScalar) {
 | 
				
			||||||
              kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
 | 
					              kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
            } else if (bopt == BinaryOpType::VectorVector) {
 | 
					            } else if (bopt == BinaryOpType::VectorVector) {
 | 
				
			||||||
              kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
 | 
					              kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            auto [num_blocks, block_dims] = get_launch_args(
 | 
					            auto [num_blocks, block_dims] = get_launch_args(
 | 
				
			||||||
                kernel,
 | 
					                kernel,
 | 
				
			||||||
                out_a.data_size(),
 | 
					                out_a.data_size(),
 | 
				
			||||||
                out_a.shape(),
 | 
					                out_a.shape(),
 | 
				
			||||||
                out_a.strides(),
 | 
					                out_a.strides(),
 | 
				
			||||||
                large());
 | 
					                large(),
 | 
				
			||||||
 | 
					                N_READS);
 | 
				
			||||||
            encoder.add_kernel_node(
 | 
					            encoder.add_kernel_node(
 | 
				
			||||||
                kernel,
 | 
					                kernel,
 | 
				
			||||||
                num_blocks,
 | 
					                num_blocks,
 | 
				
			||||||
@@ -237,7 +311,7 @@ void binary_op_gpu_inplace(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op>
 | 
					template <typename Op>
 | 
				
			||||||
void binary_op_gpu(
 | 
					void binary_two_op_gpu(
 | 
				
			||||||
    const std::vector<array>& inputs,
 | 
					    const std::vector<array>& inputs,
 | 
				
			||||||
    std::vector<array>& outputs,
 | 
					    std::vector<array>& outputs,
 | 
				
			||||||
    std::string_view op,
 | 
					    std::string_view op,
 | 
				
			||||||
@@ -247,7 +321,7 @@ void binary_op_gpu(
 | 
				
			|||||||
  auto bopt = get_binary_op_type(a, b);
 | 
					  auto bopt = get_binary_op_type(a, b);
 | 
				
			||||||
  set_binary_op_output_data(a, b, outputs[0], bopt);
 | 
					  set_binary_op_output_data(a, b, outputs[0], bopt);
 | 
				
			||||||
  set_binary_op_output_data(a, b, outputs[1], bopt);
 | 
					  set_binary_op_output_data(a, b, outputs[1], bopt);
 | 
				
			||||||
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
 | 
					  binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void DivMod::eval_gpu(
 | 
					void DivMod::eval_gpu(
 | 
				
			||||||
@@ -255,7 +329,7 @@ void DivMod::eval_gpu(
 | 
				
			|||||||
    std::vector<array>& outputs) {
 | 
					    std::vector<array>& outputs) {
 | 
				
			||||||
  nvtx3::scoped_range r("DivMod::eval_gpu");
 | 
					  nvtx3::scoped_range r("DivMod::eval_gpu");
 | 
				
			||||||
  auto& s = outputs[0].primitive().stream();
 | 
					  auto& s = outputs[0].primitive().stream();
 | 
				
			||||||
  binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
 | 
					  binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
} // namespace mlx::core
 | 
					} // namespace mlx::core
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,19 +10,43 @@ namespace cu {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
namespace cg = cooperative_groups;
 | 
					namespace cg = cooperative_groups;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename In, typename Out, typename IdxT>
 | 
					template <typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void copy_s(const In* in, Out* out, IdxT size) {
 | 
					__global__ void copy_s(const In* in, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    out[index] = CastOp<In, Out>{}(in[0]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
 | 
					      out[i] = cast_to<Out>(in[0]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      out_vec.val[i] = cast_to<Out>(in[0]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out, index, out_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename In, typename Out, typename IdxT>
 | 
					template <typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void copy_v(const In* in, Out* out, IdxT size) {
 | 
					__global__ void copy_v(const In* in, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    out[index] = CastOp<In, Out>{}(in[index]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
 | 
					      out[i] = cast_to<Out>(in[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto in_vec = load_vector<N_READS>(in, index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out, index, out_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -41,12 +65,19 @@ void copy_contiguous(
 | 
				
			|||||||
        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
 | 
					        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
 | 
				
			||||||
        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
 | 
					        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
 | 
				
			||||||
        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
					        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
				
			||||||
        auto kernel = cu::copy_s<InType, OutType, IdxT>;
 | 
					        // TODO: Choose optimized value based on type size.
 | 
				
			||||||
 | 
					        constexpr int N_READS = 4;
 | 
				
			||||||
 | 
					        auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
        if (ctype == CopyType::Vector) {
 | 
					        if (ctype == CopyType::Vector) {
 | 
				
			||||||
          kernel = cu::copy_v<InType, OutType, IdxT>;
 | 
					          kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        auto [num_blocks, block_dims] = get_launch_args(
 | 
					        auto [num_blocks, block_dims] = get_launch_args(
 | 
				
			||||||
            kernel, out.data_size(), out.shape(), out.strides(), large());
 | 
					            kernel,
 | 
				
			||||||
 | 
					            out.data_size(),
 | 
				
			||||||
 | 
					            out.shape(),
 | 
				
			||||||
 | 
					            out.strides(),
 | 
				
			||||||
 | 
					            large(),
 | 
				
			||||||
 | 
					            N_READS);
 | 
				
			||||||
        encoder.add_kernel_node(
 | 
					        encoder.add_kernel_node(
 | 
				
			||||||
            kernel,
 | 
					            kernel,
 | 
				
			||||||
            num_blocks,
 | 
					            num_blocks,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,12 +15,27 @@ namespace cu {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
namespace cg = cooperative_groups;
 | 
					namespace cg = cooperative_groups;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename T, typename IdxT>
 | 
					template <typename Op, typename T, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void
 | 
					__global__ void
 | 
				
			||||||
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
 | 
					ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    out[index] = Op{}(a[index], b[index], c[index]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
 | 
					      out[i] = Op{}(a[i], b[i], c[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto a_vec = load_vector<N_READS>(a, index);
 | 
				
			||||||
 | 
					    auto b_vec = load_vector<N_READS>(b, index);
 | 
				
			||||||
 | 
					    auto c_vec = load_vector<N_READS>(c, index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AlignedVector<T, N_READS> out_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out, index, out_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -149,11 +164,18 @@ void ternary_op_gpu_inplace(
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
          });
 | 
					          });
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
 | 
					      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
 | 
				
			||||||
        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
					        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
				
			||||||
        auto kernel = cu::ternary_v<Op, DType, IdxT>;
 | 
					        // TODO: Choose optimized value based on type size.
 | 
				
			||||||
 | 
					        constexpr int N_READS = 4;
 | 
				
			||||||
 | 
					        auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
 | 
				
			||||||
        auto [num_blocks, block_dims] = get_launch_args(
 | 
					        auto [num_blocks, block_dims] = get_launch_args(
 | 
				
			||||||
            kernel, out.data_size(), out.shape(), out.strides(), large());
 | 
					            kernel,
 | 
				
			||||||
 | 
					            out.data_size(),
 | 
				
			||||||
 | 
					            out.shape(),
 | 
				
			||||||
 | 
					            out.strides(),
 | 
				
			||||||
 | 
					            large(),
 | 
				
			||||||
 | 
					            N_READS);
 | 
				
			||||||
        encoder.add_kernel_node(
 | 
					        encoder.add_kernel_node(
 | 
				
			||||||
            kernel,
 | 
					            kernel,
 | 
				
			||||||
            num_blocks,
 | 
					            num_blocks,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,11 +18,24 @@ namespace cu {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
namespace cg = cooperative_groups;
 | 
					namespace cg = cooperative_groups;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename Op, typename In, typename Out, typename IdxT>
 | 
					template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 | 
				
			||||||
__global__ void unary_v(const In* in, Out* out, IdxT size) {
 | 
					__global__ void unary_v(const In* in, Out* out, IdxT size) {
 | 
				
			||||||
  IdxT index = cg::this_grid().thread_rank();
 | 
					  IdxT index = cg::this_grid().thread_rank();
 | 
				
			||||||
  if (index < size) {
 | 
					
 | 
				
			||||||
    out[index] = Op{}(in[index]);
 | 
					  if ((index + 1) * N_READS > size) {
 | 
				
			||||||
 | 
					    for (IdxT i = index * N_READS; i < size; ++i) {
 | 
				
			||||||
 | 
					      out[i] = Op{}(in[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto in_vec = load_vector<N_READS>(in, index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    AlignedVector<Out, N_READS> out_vec;
 | 
				
			||||||
 | 
					#pragma unroll
 | 
				
			||||||
 | 
					    for (int i = 0; i < N_READS; ++i) {
 | 
				
			||||||
 | 
					      out_vec.val[i] = Op{}(in_vec.val[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    store_vector<N_READS>(out, index, out_vec);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -112,14 +125,20 @@ void unary_op_gpu_inplace(
 | 
				
			|||||||
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
 | 
					      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
 | 
				
			||||||
      if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
 | 
					      if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
 | 
				
			||||||
        dispatch_bool(large, [&](auto large) {
 | 
					        dispatch_bool(large, [&](auto large) {
 | 
				
			||||||
          using IdxT = std::conditional_t<large(), int64_t, int32_t>;
 | 
					 | 
				
			||||||
          using InType = cuda_type_t<CTYPE_IN>;
 | 
					          using InType = cuda_type_t<CTYPE_IN>;
 | 
				
			||||||
          using OutType = cuda_type_t<CTYPE_OUT>;
 | 
					          using OutType = cuda_type_t<CTYPE_OUT>;
 | 
				
			||||||
          using IdxT = std::conditional_t<large(), int64_t, int32_t>;
 | 
					 | 
				
			||||||
          if (contig) {
 | 
					          if (contig) {
 | 
				
			||||||
            auto kernel = cu::unary_v<Op, InType, OutType, IdxT>;
 | 
					            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
 | 
				
			||||||
 | 
					            // TODO: Choose optimized value based on type size.
 | 
				
			||||||
 | 
					            constexpr int N_READS = 4;
 | 
				
			||||||
 | 
					            auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
 | 
				
			||||||
            auto [num_blocks, block_dims] = get_launch_args(
 | 
					            auto [num_blocks, block_dims] = get_launch_args(
 | 
				
			||||||
                kernel, out.data_size(), out.shape(), out.strides(), large);
 | 
					                kernel,
 | 
				
			||||||
 | 
					                out.data_size(),
 | 
				
			||||||
 | 
					                out.shape(),
 | 
				
			||||||
 | 
					                out.strides(),
 | 
				
			||||||
 | 
					                large,
 | 
				
			||||||
 | 
					                N_READS);
 | 
				
			||||||
            encoder.add_kernel_node(
 | 
					            encoder.add_kernel_node(
 | 
				
			||||||
                kernel,
 | 
					                kernel,
 | 
				
			||||||
                num_blocks,
 | 
					                num_blocks,
 | 
				
			||||||
@@ -128,6 +147,7 @@ void unary_op_gpu_inplace(
 | 
				
			|||||||
                out.data<OutType>(),
 | 
					                out.data<OutType>(),
 | 
				
			||||||
                out.data_size());
 | 
					                out.data_size());
 | 
				
			||||||
          } else {
 | 
					          } else {
 | 
				
			||||||
 | 
					            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
 | 
				
			||||||
            auto [shape, strides] = collapse_contiguous_dims(in);
 | 
					            auto [shape, strides] = collapse_contiguous_dims(in);
 | 
				
			||||||
            auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
 | 
					            auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
 | 
				
			||||||
            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
 | 
					            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user