From e7d2ebadd29e7a76379dba048a419f1e17e82b32 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 14 Jul 2025 15:45:44 -0700
Subject: [PATCH] [CUDA] Affine quantize (#2354)

* affine quantize and dequantize kernels

* format

* fix

* format
---
 mlx/backend/cuda/CMakeLists.txt |   1 +
 mlx/backend/cuda/primitives.cu  |   1 -
 mlx/backend/cuda/quantized.cu   | 383 ++++++++++++++++++++++++++++++++
 python/tests/cuda_skip.py       |   1 -
 4 files changed, 384 insertions(+), 2 deletions(-)
 create mode 100644 mlx/backend/cuda/quantized.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 29f2eeab6..9f236b4ea 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -42,6 +42,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
 
 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 3a3f8ff54..a7f4e8f66 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -91,7 +91,6 @@ NO_GPU_MULTI(Eigh)
 namespace fast {
 
 NO_GPU(ScaledDotProductAttention)
-NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
 
 } // namespace fast
diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized.cu
new file mode 100644
index 000000000..12a1f6fe4
--- /dev/null
+++ b/mlx/backend/cuda/quantized.cu
@@ -0,0 +1,383 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/fast_primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <type_traits>
+
+namespace mlx::core {
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <int bits, int wsize = 8>
+inline constexpr __device__ short get_pack_factor() {
+  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+
+template <int bits, int wsize = 8>
+inline constexpr __device__ short get_bytes_per_pack() {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
+
+template <typename T, int group_size, int bits>
+__global__ void
+affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
+  auto block_size = cg::this_thread_block().dim_threads();
+  auto block_idx = cg::this_thread_block().group_index();
+  auto idx_in_block = cg::this_thread_block().thread_index();
+
+  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
+  auto tidy = block_idx.y * block_size.y + idx_in_block.y;
+
+  auto grid_dim = cg::this_grid().dim_threads();
+  constexpr float eps = 1e-7;
+  constexpr int simd_size = WARP_SIZE;
+  constexpr float n_bins = (1 << bits) - 1;
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+  constexpr int values_per_reduce = group_size / simd_size;
+  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
+  constexpr int writes_per_pack =
+      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+
+  size_t offset = tidx + grid_dim.x * size_t(tidy);
+  size_t in_index = offset * values_per_reduce;
+  if (in_index >= size) {
+    return;
+  }
+  size_t out_index = power_of_2_bits
+      ? offset * writes_per_pack
+      : offset * bytes_per_pack / writes_per_reduce;
+
+  float w_thread[values_per_reduce];
+  float w_min = Limits<float>::max();
+  float w_max = 0;
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    float val = w[in_index + i];
+    w_thread[i] = val;
+    w_min = min(w_min, val);
+    w_max = max(w_max, val);
+  }
+
+  cg::greater<float> max_op;
+  cg::less<float> min_op;
+  auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());
+
+  w_min = cg::reduce(warp, w_min, min_op);
+  w_max = cg::reduce(warp, w_max, max_op);
+
+  float scale = max((w_max - w_min) / n_bins, eps);
+  bool side = abs(w_min) > abs(w_max);
+  scale = side ? scale : -scale;
+  float edge = side ? w_min : w_max;
+  float q0 = round(edge / scale);
+  bool at_zero = q0 == 0.0f;
+  scale = at_zero ? scale : edge / q0;
+  float bias = at_zero ? 0 : edge;
+
+  // Write out the scales and biases
+  size_t gindex = in_index / group_size;
+  if (in_index % group_size == 0) {
+    scales[gindex] = static_cast<T>(scale);
+    biases[gindex] = static_cast<T>(bias);
+  }
+
+  using OutType = std::conditional_t<bits == 5, uint64_t, uint32_t>;
+  OutType output = 0;
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output |= val << (bits * (i % pack_factor));
+    }
+
+    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
+      out[out_index + i / pack_factor] = output;
+      output = 0;
+    } else {
+#pragma clang loop unroll(full)
+      for (int j = 1; j < writes_per_reduce; j++) {
+        uint8_t sval = warp.shfl_down(val, j);
+        output |= static_cast<OutType>(sval)
+            << (bits * (j * values_per_reduce + i));
+      }
+    }
+  }
+  if constexpr (bits == 3 || bits == 6) {
+    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
+      out[out_index] = output & 0xff;
+      out[out_index + 1] = (output & 0xff00) >> 8;
+      out[out_index + 2] = (output & 0xff0000) >> 16;
+    }
+  } else if constexpr (bits == 5) {
+    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
+      out[out_index] = output & 0xff;
+      out[out_index + 1] = (output & 0xff00) >> 8;
+      out[out_index + 2] = (output & 0xff0000) >> 16;
+      out[out_index + 3] = (output & 0xff000000) >> 24;
+      out[out_index + 4] = (output & 0xff00000000) >> 32;
+    }
+  } else {
+    if constexpr (writes_per_reduce > 0) {
+      if (out_index % writes_per_reduce == 0) {
+        out[out_index / writes_per_reduce] = output;
+      }
+    }
+  }
+}
+
+template <typename T, int group_size, int bits>
+__global__ void affine_dequantize(
+    const uint8_t* w,
+    const T* scales,
+    const T* biases,
+    T* out,
+    size_t size) {
+  auto block_size = cg::this_thread_block().dim_threads();
+  auto block_idx = cg::this_thread_block().group_index();
+  auto idx_in_block = cg::this_thread_block().thread_index();
+
+  auto tidx = block_idx.x * block_size.x + idx_in_block.x;
+  auto tidy = block_idx.y * block_size.y + idx_in_block.y;
+
+  auto grid_dim = cg::this_grid().dim_threads();
+
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+
+  size_t offset = tidx + grid_dim.x * size_t(tidy);
+  size_t oindex = offset * pack_factor;
+
+  if (oindex >= size) {
+    return;
+  }
+
+  size_t gindex = oindex / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+  out += oindex;
+
+  if constexpr (bits == 3) {
+    w += offset * bytes_per_pack;
+    out[0] = static_cast<T>(w[0] & 0x7) * scale + bias;
+    out[1] = static_cast<T>((w[0] & 0x38) >> 3) * scale + bias;
+    out[2] = (static_cast<T>((w[0] & 0xc0) >> 6) +
+              static_cast<T>((w[1] & 0x1) << 2)) *
+            scale +
+        bias;
+    out[3] = static_cast<T>((w[1] & 0xe) >> 1) * scale + bias;
+    out[4] = static_cast<T>((w[1] & 0x70) >> 4) * scale + bias;
+    out[5] = (static_cast<T>((w[1] & 0x80) >> 7) +
+              static_cast<T>((w[2] & 0x3) << 1)) *
+            scale +
+        bias;
+    out[6] = static_cast<T>((w[2] & 0x1c) >> 2) * scale + bias;
+    out[7] = static_cast<T>((w[2] & 0xe0) >> 5) * scale + bias;
+  } else if constexpr (bits == 5) {
+    w += offset * bytes_per_pack;
+    out[0] = static_cast<T>(w[0] & 0x1f) * scale + bias;
+    out[1] = (static_cast<T>((w[0] & 0xe0) >> 5) +
+              static_cast<T>((w[1] & 0x3) << 3)) *
+            scale +
+        bias;
+    out[2] = static_cast<T>((w[1] & 0x7c) >> 2) * scale + bias;
+    out[3] = (static_cast<T>((w[1] & 0x80) >> 7) +
+              static_cast<T>((w[2] & 0xf) << 1)) *
+            scale +
+        bias;
+    out[4] = (static_cast<T>((w[2] & 0xf0) >> 4) +
+              static_cast<T>((w[3] & 0x1) << 4)) *
+            scale +
+        bias;
+    out[5] = static_cast<T>((w[3] & 0x3e) >> 1) * scale + bias;
+    out[6] = (static_cast<T>((w[3] & 0xc0) >> 6) +
+              static_cast<T>((w[4] & 0x7) << 2)) *
+            scale +
+        bias;
+    out[7] = static_cast<T>((w[4] & 0xf8) >> 3) * scale + bias;
+  } else if constexpr (bits == 6) {
+    w += offset * bytes_per_pack;
+    out[0] = static_cast<T>(w[0] & 0x3f) * scale + bias;
+    out[1] = (static_cast<T>((w[0] >> 6) & 0x03) +
+              static_cast<T>((w[1] & 0x0f) << 2)) *
+            scale +
+        bias;
+    out[2] = (static_cast<T>((w[1] >> 4) & 0x0f) +
+              static_cast<T>((w[2] & 0x03) << 4)) *
+            scale +
+        bias;
+    out[3] = static_cast<T>((w[2] >> 2) & 0x3f) * scale + bias;
+  } else {
+    uint val = w[offset];
+#pragma clang loop unroll(full)
+    for (int i = 0; i < pack_factor; i++) {
+      uint8_t d;
+      if (bits == 2) {
+        d = (val >> (bits * i)) & 0x03;
+      } else if (bits == 4) {
+        d = (val >> (bits * i)) & 0x0f;
+      } else if (bits == 8) {
+        d = val;
+      }
+      out[i] = scale * static_cast<T>(d) + bias;
+    }
+  }
+}
+
+} // namespace cu
+namespace {
+
+inline array ensure_row_contiguous(
+    const array& x,
+    cu::CommandEncoder& enc,
+    const Stream& s) {
+  if (!x.flags().row_contiguous) {
+    array x_copy(x.shape(), x.dtype(), nullptr, {});
+    copy_gpu(x, x_copy, CopyType::General, s);
+    enc.add_temporary(x_copy);
+    return x_copy;
+  } else {
+    return x;
+  }
+}
+
+} // namespace
+
+template <typename F>
+void dispatch_groups(int group_size, F&& f) {
+  switch (group_size) {
+    case 32:
+      f(std::integral_constant<int, 32>{});
+      break;
+    case 64:
+      f(std::integral_constant<int, 64>{});
+      break;
+    case 128:
+      f(std::integral_constant<int, 128>{});
+      break;
+  }
+}
+
+template <typename F>
+void dispatch_bits(int bits, F&& f) {
+  switch (bits) {
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 3:
+      f(std::integral_constant<int, 3>{});
+      break;
+    case 4:
+      f(std::integral_constant<int, 4>{});
+      break;
+    case 5:
+      f(std::integral_constant<int, 5>{});
+      break;
+    case 6:
+      f(std::integral_constant<int, 6>{});
+      break;
+    case 8:
+      f(std::integral_constant<int, 8>{});
+      break;
+  }
+}
+
+void fast::AffineQuantize::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& w_pre = inputs[0];
+  auto& out = outputs[0];
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& s = stream();
+  auto& d = cu::device(s.device);
+  auto& enc = d.get_command_encoder(s);
+
+  auto w = ensure_row_contiguous(w_pre, enc, s);
+  enc.set_input_array(w);
+  if (dequantize_) {
+    auto scales = ensure_row_contiguous(inputs[1], enc, s);
+    auto biases = ensure_row_contiguous(inputs[2], enc, s);
+    enc.set_input_array(scales);
+    enc.set_input_array(biases);
+    enc.set_output_array(out);
+  } else {
+    auto& scales = outputs[1];
+    auto& biases = outputs[2];
+    scales.set_data(allocator::malloc(scales.nbytes()));
+    biases.set_data(allocator::malloc(biases.nbytes()));
+    enc.set_output_array(out);
+    enc.set_output_array(scales);
+    enc.set_output_array(biases);
+  }
+
+  auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
+
+  // Treat uint32 as uint8 in kernel
+  int uint8_per_uint32 = 4;
+  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
+      : bits_ == 6                               ? 4
+                                                 : 8 / bits_;
+  int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
+  size_t size =
+      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
+
+  bool large = size > UINT_MAX;
+  auto grid_shape = w.shape();
+
+  if (dequantize_) {
+    grid_shape.back() *= uint8_per_uint32;
+  } else {
+    grid_shape.back() /= per_thread;
+  }
+
+  dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
+    dispatch_groups(group_size_, [&](auto group_size) {
+      dispatch_bits(bits_, [&](auto bits) {
+        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        if (dequantize_) {
+          auto kernel = cu::affine_dequantize<DataType, group_size(), bits()>;
+          auto [num_blocks, block_dims] =
+              get_launch_args(kernel, size, grid_shape, w.strides(), large);
+          enc.add_kernel_node(
+              kernel,
+              num_blocks,
+              block_dims,
+              w.data<uint8_t>(),
+              inputs[1].data<DataType>(),
+              inputs[2].data<DataType>(),
+              out.data<DataType>(),
+              out.size());
+        } else {
+          auto kernel = cu::affine_quantize<DataType, group_size(), bits()>;
+          auto [num_blocks, block_dims] =
+              get_launch_args(kernel, size, grid_shape, w.strides(), large);
+          enc.add_kernel_node(
+              kernel,
+              num_blocks,
+              block_dims,
+              w.data<DataType>(),
+              out.data<uint8_t>(),
+              outputs[1].data<DataType>(),
+              outputs[2].data<DataType>(),
+              w.size());
+        }
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 005c612ff..7c9ff84ce 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -83,7 +83,6 @@ cuda_skip = {
     "TestQuantized.test_qmm_shapes",
     "TestQuantized.test_qmm_vjp",
     "TestQuantized.test_qmv",
-    "TestQuantized.test_quantize_dequantize",
     "TestQuantized.test_qvm",
     "TestQuantized.test_qvm_splitk",
     "TestQuantized.test_small_matrix",
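The re-enabled `TestQuantized.test_quantize_dequantize` case exercises the round trip these kernels implement: each group of weights is reconstructed as `w ≈ w_q * scale + bias`. A minimal sketch of that round trip through the Python API (the shapes, `group_size`, and `bits` below are illustrative choices, not values taken from the test):

```python
import mlx.core as mx

# Quantize a weight matrix; with the CUDA backend this now dispatches to the
# affine_quantize kernel added in this patch.
w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

# Dequantize and check that the reconstruction error is on the order of one
# quantization step per group.
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
print(mx.abs(w - w_hat).max())
```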