mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use uint32_t as the index type for contiguous arrays and int32_t for non-contiguous arrays
This commit is contained in:
@@ -247,7 +247,7 @@ void binary_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
|
|||||||
@@ -269,7 +269,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ void ternary_op_gpu_inplace(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
|
|||||||
@@ -127,8 +127,8 @@ void unary_op_gpu_inplace(
|
|||||||
dispatch_bool(large, [&](auto large) {
|
dispatch_bool(large, [&](auto large) {
|
||||||
using InType = cuda_type_t<CTYPE_IN>;
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
|
||||||
if (contig) {
|
if (contig) {
|
||||||
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
// TODO: Choose optimized value based on type size.
|
// TODO: Choose optimized value based on type size.
|
||||||
constexpr int N_READS = 4;
|
constexpr int N_READS = 4;
|
||||||
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
|
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
|
||||||
@@ -147,6 +147,7 @@ void unary_op_gpu_inplace(
|
|||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size());
|
out.data_size());
|
||||||
} else {
|
} else {
|
||||||
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||||
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
|
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||||
|
|||||||
Reference in New Issue
Block a user