diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 5bc75e2e0a..8c1b999e99 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -46,7 +46,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized/affine_quantize.cu
similarity index 66%
rename from mlx/backend/cuda/quantized.cu
rename to mlx/backend/cuda/quantized/affine_quantize.cu
index 5702fa5a9f..55322fa3e8 100644
--- a/mlx/backend/cuda/quantized.cu
+++ b/mlx/backend/cuda/quantized/affine_quantize.cu
@@ -2,30 +2,17 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
 #include "mlx/dtype_utils.h"
-#include "mlx/fast_primitives.h"

 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
-#include <type_traits>

 namespace mlx::core {

 namespace cu {

 namespace cg = cooperative_groups;

-template <int bits, int wsize = 8>
-inline constexpr __device__ short get_pack_factor() {
-  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
-}
-
-template <int bits, int wsize = 8>
-inline constexpr __device__ short get_bytes_per_pack() {
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
-}
-
 template <typename T, int group_size, int bits>
 __global__ void
 affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,140 +227,100 @@ __global__ void affine_dequantize(
   }
 }

 } // namespace cu

-namespace {
-inline array ensure_row_contiguous(
-    const array& x,
+void affine_quantize(
+    const array& w,
+    array& wq,
+    array& scales,
+    array& biases,
+    int group_size_,
+    int bits_,
     cu::CommandEncoder& enc,
     const Stream& s) {
-  if (!x.flags().row_contiguous) {
-    array x_copy = contiguous_copy_gpu(x, s);
-    enc.add_temporary(x_copy);
-    return x_copy;
-  } else {
-    return x;
-  }
-}
-
-} // namespace
-
-template <typename F>
-void dispatch_groups(int group_size, F&& f) {
-  switch (group_size) {
-    case 32:
-      f(std::integral_constant<int, 32>{});
-      break;
-    case 64:
-      f(std::integral_constant<int, 64>{});
-      break;
-    case 128:
-      f(std::integral_constant<int, 128>{});
-      break;
-  }
-}
-
-template <typename F>
-void dispatch_bits(int bits, F&& f) {
-  switch (bits) {
-    case 2:
-      f(std::integral_constant<int, 2>{});
-      break;
-    case 3:
-      f(std::integral_constant<int, 3>{});
-      break;
-    case 4:
-      f(std::integral_constant<int, 4>{});
-      break;
-    case 5:
-      f(std::integral_constant<int, 5>{});
-      break;
-    case 6:
-      f(std::integral_constant<int, 6>{});
-      break;
-    case 8:
-      f(std::integral_constant<int, 8>{});
-      break;
-  }
-}
-
-void fast::AffineQuantize::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  auto& w_pre = inputs[0];
-  auto& out = outputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
-
-  auto& s = stream();
-  auto& d = cu::device(s.device);
-  auto& enc = d.get_command_encoder(s);
-
-  auto w = ensure_row_contiguous(w_pre, enc, s);
-  enc.set_input_array(w);
-  if (dequantize_) {
-    auto scales = ensure_row_contiguous(inputs[1], enc, s);
-    auto biases = ensure_row_contiguous(inputs[2], enc, s);
-    enc.set_input_array(scales);
-    enc.set_input_array(biases);
-    enc.set_output_array(out);
-  } else {
-    auto& scales = outputs[1];
-    auto& biases = outputs[2];
-    scales.set_data(allocator::malloc(scales.nbytes()));
-    biases.set_data(allocator::malloc(biases.nbytes()));
-    enc.set_output_array(out);
-    enc.set_output_array(scales);
-    enc.set_output_array(biases);
-  }
-
-  auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
-
-  // Treat uint32 as uint8 in kernel
-  int uint8_per_uint32 = 4;
-  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
-      : bits_ == 6                               ? 4
-                                                 : 8 / bits_;
-  int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
-  size_t size =
-      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
+  // Calculate the number of elements per thread
+  int per_thread = group_size_ / WARP_SIZE;
+  size_t size = w.size() / per_thread;

+  // Calculate the thread grid that we need to launch
   bool large = size > UINT_MAX;
   auto grid_shape = w.shape();
+  grid_shape.back() /= per_thread;

-  if (dequantize_) {
-    grid_shape.back() *= uint8_per_uint32;
-  } else {
-    grid_shape.back() /= per_thread;
-  }
-
-  dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
+  enc.set_input_array(w);
+  enc.set_output_array(wq);
+  enc.set_output_array(scales);
+  enc.set_output_array(biases);
+  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
     dispatch_groups(group_size_, [&](auto group_size) {
       dispatch_bits(bits_, [&](auto bits) {
-        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
-        if (dequantize_) {
-          auto [num_blocks, block_dims] =
-              get_launch_args(size, grid_shape, w.strides(), large);
-          enc.add_kernel_node(
-              cu::affine_dequantize<DataType, group_size.value, bits.value>,
-              num_blocks,
-              block_dims,
-              w.data<uint8_t>(),
-              inputs[1].data<DataType>(),
-              inputs[2].data<DataType>(),
-              out.data<DataType>(),
-              out.size());
-        } else {
-          auto [num_blocks, block_dims] =
-              get_launch_args(size, grid_shape, w.strides(), large);
-          enc.add_kernel_node(
-              cu::affine_quantize<DataType, group_size.value, bits.value>,
-              num_blocks,
-              block_dims,
-              w.data<DataType>(),
-              out.data<uint8_t>(),
-              outputs[1].data<DataType>(),
-              outputs[2].data<DataType>(),
-              w.size());
-        }
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
+        auto [num_blocks, block_dims] =
+            get_launch_args(size, grid_shape, w.strides(), large);
+        enc.add_kernel_node(
+            kernel,
+            num_blocks,
+            block_dims,
+            w.data<T>(),
+            wq.data<uint8_t>(),
+            scales.data<T>(),
+            biases.data<T>(),
+            w.size());
+      });
+    });
+  });
+}
+
+void affine_dequantize(
+    const array& wq,
+    const array& scales,
+    const array& biases,
+    array& w,
+    int group_size_,
+    int bits_,
+    cu::CommandEncoder& enc,
+    const Stream& s) {
+  // Calculate how many values we pack together. For 2, 4 and 8 bits we pack
+  // in one uint8, for 3 and 6 bits in 3 uint8s, and for 5 bits in 5 uint8s.
+  constexpr int uint8_per_uint32 = 4;
+  int packs_per_int;
+  switch (bits_) {
+    case 3:
+    case 5:
+      packs_per_int = 8;
+      break;
+    case 6:
+      packs_per_int = 4;
+      break;
+    default:
+      packs_per_int = 8 / bits_;
+  }
+
+  size_t size = w.size() / packs_per_int;
+  bool large = size > UINT_MAX;
+  auto grid_shape = w.shape();
+  grid_shape.back() *= uint8_per_uint32;
+
+  enc.set_input_array(wq);
+  enc.set_input_array(scales);
+  enc.set_input_array(biases);
+  enc.set_output_array(w);
+  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
+    dispatch_groups(group_size_, [&](auto group_size) {
+      dispatch_bits(bits_, [&](auto bits) {
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
+        auto [num_blocks, block_dims] =
+            get_launch_args(size, grid_shape, w.strides(), large);
+        enc.add_kernel_node(
+            kernel,
+            num_blocks,
+            block_dims,
+            wq.data<uint8_t>(),
+            scales.data<T>(),
+            biases.data<T>(),
+            w.data<T>(),
+            w.size());
       });
     });
   });
diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp
new file mode 100644
index 0000000000..f495af53b3
--- /dev/null
+++ b/mlx/backend/cuda/quantized/quantized.cpp
@@ -0,0 +1,72 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/quantized/quantized.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/fast_primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+inline array ensure_row_contiguous(
+    const array& x,
+    cu::CommandEncoder& enc,
+    const Stream& s) {
+  if (!x.flags().row_contiguous) {
+    array x_copy = contiguous_copy_gpu(x, s);
+    enc.add_temporary(x_copy);
+    return x_copy;
+  } else {
+    return x;
+  }
+}
+
+inline array ensure_row_contiguous_matrix(
+    const array& x,
+    cu::CommandEncoder& enc,
+    const Stream& s) {
+  auto stride_0 = x.strides()[x.ndim() - 2];
+  auto stride_1 = x.strides()[x.ndim() - 1];
+  if (stride_0 == x.shape(-1) && stride_1 == 1) {
+    return x;
+  } else {
+    array x_copy = contiguous_copy_gpu(x, s);
+    enc.add_temporary(x_copy);
+    return x_copy;
+  }
+}
+
+} // namespace
+
+void fast::AffineQuantize::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& s = stream();
+  auto& d = cu::device(s.device);
+  auto& enc = d.get_command_encoder(s);
+
+  if (dequantize_) {
+    auto wq = ensure_row_contiguous(inputs[0], enc, s);
+    auto scales = ensure_row_contiguous(inputs[1], enc, s);
+    auto biases = ensure_row_contiguous(inputs[2], enc, s);
+    auto& w = outputs[0];
+
+    w.set_data(allocator::malloc(w.nbytes()));
+
+    affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
+  } else {
+    auto w = ensure_row_contiguous(inputs[0], enc, s);
+    auto& wq = outputs[0];
+    auto& scales = outputs[1];
+    auto& biases = outputs[2];
+
+    wq.set_data(allocator::malloc(wq.nbytes()));
+    scales.set_data(allocator::malloc(scales.nbytes()));
+    biases.set_data(allocator::malloc(biases.nbytes()));
+
+    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h
new file mode 100644
index 0000000000..ec6a08000b
--- /dev/null
+++ b/mlx/backend/cuda/quantized/quantized.h
@@ -0,0 +1,27 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+void affine_quantize(
+    const array& w,
+    array& wq,
+    array& scales,
+    array& biases,
+    int group_size_,
+    int bits_,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
+void affine_dequantize(
+    const array& wq,
+    const array& scales,
+    const array& biases,
+    array& w,
+    int group_size_,
+    int bits_,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh
new file mode 100644
index 0000000000..c6a85527cc
--- /dev/null
+++ b/mlx/backend/cuda/quantized/quantized_utils.cuh
@@ -0,0 +1,59 @@
+// Copyright © 2025 Apple Inc.
+
+namespace mlx::core {
+
+namespace cu {
+
+template <int bits, int wsize = 8>
+inline constexpr __device__ short get_pack_factor() {
+  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+
+template <int bits, int wsize = 8>
+inline constexpr __device__ short get_bytes_per_pack() {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
+
+} // namespace cu
+
+template <typename F>
+void dispatch_groups(int group_size, F&& f) {
+  switch (group_size) {
+    case 32:
+      f(std::integral_constant<int, 32>{});
+      break;
+    case 64:
+      f(std::integral_constant<int, 64>{});
+      break;
+    case 128:
+      f(std::integral_constant<int, 128>{});
+      break;
+  }
+}
+
+template <typename F>
+void dispatch_bits(int bits, F&& f) {
+  switch (bits) {
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 3:
+      f(std::integral_constant<int, 3>{});
+      break;
+    case 4:
+      f(std::integral_constant<int, 4>{});
+      break;
+    case 5:
+      f(std::integral_constant<int, 5>{});
+      break;
+    case 6:
+      f(std::integral_constant<int, 6>{});
+      break;
+    case 8:
+      f(std::integral_constant<int, 8>{});
+      break;
+  }
+}
+
+} // namespace mlx::core
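
Note on the packing scheme: get_pack_factor() and get_bytes_per_pack() in quantized_utils.cuh encode the invariant that a pack of quantized values always fills a whole number of bytes: 8 values in 3 bytes for 3-bit, 8 values in 5 bytes for 5-bit, 4 values in 3 bytes for 6-bit, and wsize / bits values in one byte for the power-of-two widths. Below is a standalone host-side sketch of the same arithmetic; pack_factor and bytes_per_pack are hypothetical mirrors of the device helpers, not part of this patch.

// pack_invariants.cpp -- illustrative only, not part of the diff.
#include <cstdio>

constexpr int pack_factor(int bits, int wsize = 8) {
  // 3- and 5-bit values pack 8 at a time; 6-bit values pack 4 at a time;
  // power-of-two widths divide a wsize-bit word evenly.
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

constexpr int bytes_per_pack(int bits, int wsize = 8) {
  // Power-of-two widths fill wsize / 8 bytes; a 5-bit pack spans 5 bytes
  // (8 * 5 = 40 bits); 3- and 6-bit packs span 3 bytes (24 bits).
  bool power_of_2 = (bits & (bits - 1)) == 0;
  return power_of_2 ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

int main() {
  // Every pack is byte-aligned: values-per-pack * bits == bytes-per-pack * 8.
  static_assert(pack_factor(3) * 3 == bytes_per_pack(3) * 8, "3-bit packs");
  static_assert(pack_factor(5) * 5 == bytes_per_pack(5) * 8, "5-bit packs");
  static_assert(pack_factor(6) * 6 == bytes_per_pack(6) * 8, "6-bit packs");
  for (int bits : {2, 3, 4, 5, 6, 8}) {
    std::printf(
        "%d bits: %d values in %d byte(s)\n",
        bits,
        pack_factor(bits),
        bytes_per_pack(bits));
  }
  return 0;
}

This is the same arithmetic that affine_dequantize() performs at runtime: packs_per_int is 8 for 3- and 5-bit, 4 for 6-bit, and 8 / bits_ otherwise.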
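Note on the dispatch helpers: dispatch_groups() and dispatch_bits() use the usual switch-to-std::integral_constant idiom to turn the runtime group_size_ and bits_ members into compile-time template arguments, so each (type, group size, bit width) combination instantiates its own kernel. A minimal self-contained sketch of the pattern follows; dispatch_bits_demo and fake_kernel are hypothetical stand-ins, not MLX code.

// dispatch_demo.cpp -- illustrative only.
#include <cstdio>
#include <type_traits>

// Runtime switch that invokes a generic lambda with std::integral_constant,
// converting a runtime int into a compile-time constant.
template <typename F>
void dispatch_bits_demo(int bits, F&& f) {
  switch (bits) {
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

// Stand-in for a kernel template specialized on the bit width.
template <int bits>
void fake_kernel() {
  std::printf("instantiated for %d bits\n", bits);
}

int main() {
  int runtime_bits = 4; // e.g. the primitive's bits_ member
  dispatch_bits_demo(runtime_bits, [&](auto bits) {
    // bits.value is a constant expression, so it can parameterize a template.
    fake_kernel<bits.value>();
  });
  return 0;
}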