From 91817a165b203776f8ccfc2f99d32a7ef39235fe Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 16 Jun 2025 07:46:40 -0700
Subject: [PATCH] format

---
 mlx/backend/cuda/binary.cu               |  4 ++--
 mlx/backend/cuda/copy/copy.cuh           | 14 +++++++-------
 mlx/backend/cuda/copy/copy_contiguous.cu |  4 ++--
 mlx/backend/cuda/device/cast_op.cuh      |  1 -
 mlx/backend/cuda/kernel_utils.cuh        |  3 ++-
 mlx/backend/cuda/ternary.cu              |  3 ++-
 mlx/backend/cuda/unary.cu                |  4 ++--
 7 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index 0d2389de1..d4df06f18 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
         } else if (bopt == BinaryOpType::VectorVector) {
           kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
         }
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+        auto [num_blocks, block_dims] = get_launch_args(
+            kernel, out.data_size(), out.shape(), out.strides(), LARGE);
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             a.data<InType>(),
             b.data<InType>(),
diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh
index ee5120274..789826507 100644
--- a/mlx/backend/cuda/copy/copy.cuh
+++ b/mlx/backend/cuda/copy/copy.cuh
@@ -10,13 +10,13 @@
 
 namespace mlx::core {
 
-#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
-  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {               \
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {           \
-      using InType = cuda_type_t<CTYPE_IN>;                  \
-      using OutType = cuda_type_t<CTYPE_OUT>;                \
-      __VA_ARGS__;                                           \
-    });                                                      \
+#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...)  \
+  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {                \
+    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {            \
+      using InType = cuda_type_t<CTYPE_IN>;                   \
+      using OutType = cuda_type_t<CTYPE_OUT>;                 \
+      __VA_ARGS__;                                            \
+    });                                                       \
   })
 
 void copy_contiguous(
diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu
index 854fd93b4..5f4c9ca8f 100644
--- a/mlx/backend/cuda/copy/copy_contiguous.cu
+++ b/mlx/backend/cuda/copy/copy_contiguous.cu
@@ -43,8 +43,8 @@ void copy_contiguous(
       if (ctype == CopyType::Vector) {
         kernel = cu::copy_v<InType, OutType, IdxT>;
       }
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+      auto [num_blocks, block_dims] = get_launch_args(
+          kernel, out.data_size(), out.shape(), out.strides(), LARGE);
       kernel<<<num_blocks, block_dims, 0, stream>>>(
           in.data<InType>() + in_offset,
           out.data<OutType>() + out_offset,
diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh
index 115395db7..f15270432 100644
--- a/mlx/backend/cuda/device/cast_op.cuh
+++ b/mlx/backend/cuda/device/cast_op.cuh
@@ -57,7 +57,6 @@ struct CastOp<
   }
 };
 
-
 // Return an iterator that cast the value to DstT using CastOp.
 template <typename DstT, typename Iterator>
 __host__ __device__ auto make_cast_iterator(Iterator it) {
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 59b48f886..b1fe875bd 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -167,7 +167,8 @@ inline std::tuple<dim3, uint32_t> get_launch_args(
     const array& arr,
     bool large,
     int work_per_thread = 1) {
-  return get_launch_args(kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+  return get_launch_args(
+      kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 41441ff40..02e46afc1 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
     MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
       using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
       auto kernel = cu::ternary_v<Op, DType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+      auto [num_blocks, block_dims] = get_launch_args(
+          kernel, out.data_size(), out.shape(), out.strides(), LARGE);
       kernel<<<num_blocks, block_dims, 0, stream>>>(
           a.data<bool>(),
           b.data<DType>(),
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 1cff4665e..d2fa96381 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -28,8 +28,8 @@ constexpr bool supports_unary_op() {
       std::is_same_v || std::is_same_v ||
       std::is_same_v || std::is_same_v ||
       std::is_same_v || std::is_same_v ||
-      std::is_same_v ||
-      std::is_same_v || std::is_same_v) {
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v) {
     return std::is_same_v && is_floating_v;
   }
   if (std::is_same_v || std::is_same_v ||