mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int32_t for IdxT
This commit is contained in:
@@ -268,7 +268,7 @@ void binary_op_gpu_inplace(
|
|||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
|
|||||||
@@ -286,7 +286,7 @@ void binary_op_gpu_inplace(
|
|||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
|
|||||||
@@ -71,10 +71,10 @@ void copy_contiguous(
|
|||||||
int64_t out_offset) {
|
int64_t out_offset) {
|
||||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
|
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
|
||||||
|
|||||||
@@ -34,10 +34,9 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
|||||||
auto b_vec = load_vector<N_READS>(b, index);
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
auto c_vec = load_vector<N_READS>(c, index);
|
auto c_vec = load_vector<N_READS>(c, index);
|
||||||
|
|
||||||
AlignedVector<Out, N_READS> out_vec;
|
AlignedVector<T, N_READS> out_vec;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
out_vec.val[i] = CastOp<In, Out>{}(a_vec.val[i]);
|
|
||||||
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
|
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,7 +170,7 @@ void ternary_op_gpu_inplace(
|
|||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
|
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
|
||||||
|
|||||||
Reference in New Issue
Block a user