commit 91817a165b
parent 14531cb14f
Awni Hannun, 2025-06-16 07:46:40 -07:00
7 changed files with 17 additions and 16 deletions


@@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
   } else if (bopt == BinaryOpType::VectorVector) {
     kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
   }
-  auto [num_blocks, block_dims] =
-      get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+  auto [num_blocks, block_dims] = get_launch_args(
+      kernel, out.data_size(), out.shape(), out.strides(), LARGE);
   kernel<<<num_blocks, block_dims, 0, stream>>>(
       a.data<InType>(),
       b.data<InType>(),
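
The call being reflowed here is the standard elementwise launch: get_launch_args sizes a 1-D grid over out.data_size(), and the kernel applies Op once per element. A minimal sketch of what a vector-vector kernel such as cu::binary_vv might look like; the real parameter list is not shown in this hunk, so the signature below is an assumption:

template <typename Op, typename InType, typename OutType, typename IdxT>
__global__ void binary_vv_sketch(
    const InType* a,
    const InType* b,
    OutType* out,
    IdxT size) {
  // One thread per element; IdxT widens to int64_t on the LARGE path.
  IdxT i = static_cast<IdxT>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < size) {
    out[i] = Op{}(a[i], b[i]);
  }
}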


@@ -10,13 +10,13 @@
 namespace mlx::core {
-#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
-  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
-      using InType = cuda_type_t<CTYPE_IN>; \
-      using OutType = cuda_type_t<CTYPE_OUT>; \
-      __VA_ARGS__; \
-    }); \
+#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
+  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {               \
+    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {           \
+      using InType = cuda_type_t<CTYPE_IN>;                  \
+      using OutType = cuda_type_t<CTYPE_OUT>;                \
+      __VA_ARGS__;                                           \
+    });                                                      \
   })
 void copy_contiguous(
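
This macro turns the two runtime dtypes into compile-time type aliases by nesting dtype switches. A hedged usage sketch follows; launch_copy is a hypothetical helper, not from the diff:

void copy_contiguous_sketch(const array& in, array& out) {
  MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
    // Inside the body, InType and OutType are concrete CUDA types
    // (e.g. float, __half), so templates can be instantiated here.
    launch_copy<InType, OutType>(in, out);
  });
}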


@@ -43,8 +43,8 @@ void copy_contiguous(
   if (ctype == CopyType::Vector) {
     kernel = cu::copy_v<InType, OutType, IdxT>;
   }
-  auto [num_blocks, block_dims] =
-      get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+  auto [num_blocks, block_dims] = get_launch_args(
+      kernel, out.data_size(), out.shape(), out.strides(), LARGE);
   kernel<<<num_blocks, block_dims, 0, stream>>>(
       in.data<InType>() + in_offset,
       out.data<OutType>() + out_offset,
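
As in the binary case, the Vector copy launches one thread per element, additionally casting between dtypes. A sketch under assumed names (cu::copy_v's real signature is not shown here):

template <typename InType, typename OutType, typename IdxT>
__global__ void copy_v_sketch(const InType* in, OutType* out, IdxT size) {
  // IdxT is uint32_t for small arrays and int64_t once
  // out.data_size() exceeds UINT32_MAX (the LARGE path).
  IdxT i = static_cast<IdxT>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < size) {
    out[i] = static_cast<OutType>(in[i]);
  }
}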


@@ -57,7 +57,6 @@ struct CastOp<
   }
 };
-
 // Return an iterator that cast the value to DstT using CastOp.
 template <typename DstT, typename Iterator>
 __host__ __device__ auto make_cast_iterator(Iterator it) {
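
A cast iterator of this shape can be assembled from Thrust's transform iterator; the sketch below is one plausible implementation under that assumption, and the names are illustrative (MLX's actual CastOp also has specializations, e.g. the one shown truncated above):

#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/transform_iterator.h>

template <typename SrcT, typename DstT>
struct CastOpSketch {
  __host__ __device__ DstT operator()(SrcT x) const {
    return static_cast<DstT>(x);
  }
};

template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator_sketch(Iterator it) {
  // Read the source type off the iterator, then cast each element on
  // the fly as the iterator is dereferenced.
  using SrcT = typename thrust::iterator_value<Iterator>::type;
  return thrust::make_transform_iterator(it, CastOpSketch<SrcT, DstT>{});
}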


@@ -167,7 +167,8 @@ inline std::tuple<dim3, uint> get_launch_args(
     const array& arr,
     bool large,
     int work_per_thread = 1) {
-  return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+  return get_launch_args(
+      kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
 }
 } // namespace mlx::core
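
This overload simply forwards to the main get_launch_args with the array's size, shape, and strides. The core of such a helper is a ceiling division; the following is a rough sketch of the idea only, not MLX's implementation, which may also consult kernel occupancy and split large grids across dimensions:

#include <tuple>

inline std::tuple<dim3, uint> launch_args_sketch(
    size_t size,
    int work_per_thread = 1) {
  uint block_dim = 256; // assumed default block size
  // Each thread handles work_per_thread elements; round up twice.
  size_t threads = (size + work_per_thread - 1) / work_per_thread;
  size_t num_blocks = (threads + block_dim - 1) / block_dim;
  return {dim3(static_cast<unsigned int>(num_blocks)), block_dim};
}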


@@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
   MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
     using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
     auto kernel = cu::ternary_v<Op, DType, IdxT>;
-    auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+    auto [num_blocks, block_dims] = get_launch_args(
+        kernel, out.data_size(), out.shape(), out.strides(), LARGE);
     kernel<<<num_blocks, block_dims, 0, stream>>>(
         a.data<bool>(),
         b.data<DType>(),
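
MLX_SWITCH_BOOL converts the runtime predicate out.data_size() > UINT32_MAX into the compile-time constant LARGE, which is what lets std::conditional_t pick the index type at compile time. A sketch of the dispatch pattern such a macro implements (the real macro's spelling may differ):

#define SWITCH_BOOL_SKETCH(cond, NAME, ...) \
  if (cond) {                               \
    constexpr bool NAME = true;             \
    __VA_ARGS__;                            \
  } else {                                  \
    constexpr bool NAME = false;            \
    __VA_ARGS__;                            \
  }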


@@ -28,8 +28,8 @@ constexpr bool supports_unary_op() {
       std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
       std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
       std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
-      std::is_same_v<Op, Sigmoid> ||
-      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
+      std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
+      std::is_same_v<Op, Rsqrt>) {
     return std::is_same_v<In, Out> && is_floating_v<In>;
   }
   if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
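
supports_unary_op is a constexpr trait, so unsupported (Op, In, Out) combinations can be rejected before any kernel is instantiated. An illustrative sketch of that gating, assuming the template parameters are ordered <Op, In, Out>; the dispatch function and launcher names are hypothetical:

#include <stdexcept>

template <typename Op, typename In, typename Out>
void unary_dispatch_sketch(const In* in, Out* out, size_t n) {
  if constexpr (supports_unary_op<Op, In, Out>()) {
    // Only instantiate the kernel for valid type combinations.
    launch_unary<Op, In, Out>(in, out, n); // hypothetical launcher
  } else {
    throw std::runtime_error("unary op not supported for these dtypes");
  }
}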