Mirror of https://github.com/ml-explore/mlx.git, synced 2025-11-04 18:48:15 +08:00
[CUDA] Quantized refactoring (#2442)

committed by GitHub
parent 2204182bba
commit 3bf81ed1bd
mlx/backend/cuda/CMakeLists.txt
@@ -46,7 +46,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)

mlx/backend/cuda/quantized/affine_quantize.cu (renamed from mlx/backend/cuda/quantized.cu)
@@ -2,30 +2,17 @@

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {
namespace cu {

namespace cg = cooperative_groups;

template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
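The kernel above stores each group of group_size weights as low-bit integers plus a per-group scale and bias (the scales/biases arguments). Under the usual affine scheme these names suggest, dequantization reconstructs w ≈ scale * q + bias. The host-side sketch below is illustrative only; it is not part of this commit, and the CUDA kernel's exact rounding, clamping, and packing may differ.

// Illustrative sketch of per-group affine quantization (not from the diff).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one group of weights to `bits` bits: q = round((w - bias) / scale),
// with scale = (max - min) / (2^bits - 1) and bias = min.
void quantize_group(
    const std::vector<float>& w,
    int bits,
    std::vector<uint32_t>& q,
    float& scale,
    float& bias) {
  float lo = *std::min_element(w.begin(), w.end());
  float hi = *std::max_element(w.begin(), w.end());
  long n_levels = (1L << bits) - 1;
  scale = (hi - lo) / n_levels;
  bias = lo;
  q.resize(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    long t = scale != 0.f ? std::lround((w[i] - bias) / scale) : 0;
    q[i] = static_cast<uint32_t>(std::clamp(t, 0L, n_levels));
  }
}

// Dequantization is the inverse map: w_hat[i] = scale * q[i] + bias.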
@@ -240,140 +227,100 @@ __global__ void affine_dequantize(
}

} // namespace cu
namespace {

inline array ensure_row_contiguous(
    const array& x,
void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (!x.flags().row_contiguous) {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  } else {
    return x;
  }
}

} // namespace

template <typename F>
void dispatch_groups(int group_size, F&& f) {
  switch (group_size) {
    case 32:
      f(std::integral_constant<int, 32>{});
      break;
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

template <typename F>
void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 5:
      f(std::integral_constant<int, 5>{});
      break;
    case 6:
      f(std::integral_constant<int, 6>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

void fast::AffineQuantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& w_pre = inputs[0];
  auto& out = outputs[0];
  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  auto w = ensure_row_contiguous(w_pre, enc, s);
  enc.set_input_array(w);
  if (dequantize_) {
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto biases = ensure_row_contiguous(inputs[2], enc, s);
    enc.set_input_array(scales);
    enc.set_input_array(biases);
    enc.set_output_array(out);
  } else {
    auto& scales = outputs[1];
    auto& biases = outputs[2];
    scales.set_data(allocator::malloc(scales.nbytes()));
    biases.set_data(allocator::malloc(biases.nbytes()));
    enc.set_output_array(out);
    enc.set_output_array(scales);
    enc.set_output_array(biases);
  }

  auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();

  // Treat uint32 as uint8 in kernel
  int uint8_per_uint32 = 4;
  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
      : bits_ == 6                               ? 4
                                                 : 8 / bits_;
  int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
  size_t size =
      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
  // Calculate the number of elements per thread
  int per_thread = group_size_ / WARP_SIZE;
  size_t size = w.size() / per_thread;

  // Calculate the thread grid that we need to launch
  bool large = size > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() /= per_thread;

  if (dequantize_) {
    grid_shape.back() *= uint8_per_uint32;
  } else {
    grid_shape.back() /= per_thread;
  }

  dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
  enc.set_input_array(w);
  enc.set_output_array(wq);
  enc.set_output_array(scales);
  enc.set_output_array(biases);
  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        if (dequantize_) {
          auto [num_blocks, block_dims] =
              get_launch_args(size, grid_shape, w.strides(), large);
          enc.add_kernel_node(
              cu::affine_dequantize<DataType, group_size.value, bits.value>,
              num_blocks,
              block_dims,
              w.data<uint8_t>(),
              inputs[1].data<DataType>(),
              inputs[2].data<DataType>(),
              out.data<DataType>(),
              out.size());
        } else {
          auto [num_blocks, block_dims] =
              get_launch_args(size, grid_shape, w.strides(), large);
          enc.add_kernel_node(
              cu::affine_quantize<DataType, group_size.value, bits.value>,
              num_blocks,
              block_dims,
              w.data<DataType>(),
              out.data<uint8_t>(),
              outputs[1].data<DataType>(),
              outputs[2].data<DataType>(),
              w.size());
        }
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
        auto [num_blocks, block_dims] =
            get_launch_args(size, grid_shape, w.strides(), large);
        enc.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            w.data<T>(),
            wq.data<uint8_t>(),
            scales.data<T>(),
            biases.data<T>(),
            w.size());
      });
    });
  });
}

void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s) {
  // Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
  // one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
  constexpr int uint8_per_uint32 = 4;
  int packs_per_int;
  switch (bits_) {
    case 3:
    case 5:
      packs_per_int = 8;
      break;
    case 6:
      packs_per_int = 4;
      break;
    default:
      packs_per_int = 8 / bits_;
  }

  size_t size = w.size() / packs_per_int;
  bool large = size > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() *= uint8_per_uint32;

  enc.set_input_array(wq);
  enc.set_input_array(scales);
  enc.set_input_array(biases);
  enc.set_output_array(w);
  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
        auto [num_blocks, block_dims] =
            get_launch_args(size, grid_shape, w.strides(), large);
        enc.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            wq.data<uint8_t>(),
            scales.data<T>(),
            biases.data<T>(),
            w.data<T>(),
            w.size());
      });
    });
  });
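The packs_per_int switch above mirrors the get_pack_factor() / get_bytes_per_pack() helpers: 2-, 4-, and 8-bit values pack into single bytes, 3- and 6-bit values pack 8 and 4 values respectively into a 3-byte pack, and 5-bit values pack 8 values into a 5-byte pack, while uint8_per_uint32 = 4 reflects that the packed uint32 storage is addressed as bytes in the kernel. A small compile-time restatement of that arithmetic (illustrative, not part of the commit):

// Host-side restatement of the get_pack_factor / get_bytes_per_pack helpers
// shown above (minus __device__), with the resulting packing checked at
// compile time.
template <int bits, int wsize = 8>
constexpr short pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
constexpr short bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

// 2-, 4-, and 8-bit values: one byte holds 4, 2, or 1 values.
static_assert(pack_factor<2>() == 4 && bytes_per_pack<2>() == 1, "");
static_assert(pack_factor<4>() == 2 && bytes_per_pack<4>() == 1, "");
static_assert(pack_factor<8>() == 1 && bytes_per_pack<8>() == 1, "");
// 3- and 6-bit values: 8 * 3 = 4 * 6 = 24 bits, i.e. a 3-byte pack.
static_assert(pack_factor<3>() == 8 && bytes_per_pack<3>() == 3, "");
static_assert(pack_factor<6>() == 4 && bytes_per_pack<6>() == 3, "");
// 5-bit values: 8 * 5 = 40 bits, i.e. a 5-byte pack.
static_assert(pack_factor<5>() == 8 && bytes_per_pack<5>() == 5, "");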
mlx/backend/cuda/quantized/quantized.cpp (new file, 72 lines)
@@ -0,0 +1,72 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"

namespace mlx::core {

namespace {

inline array ensure_row_contiguous(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (!x.flags().row_contiguous) {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  } else {
    return x;
  }
}

inline array ensure_row_contiguous_matrix(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  auto stride_0 = x.strides()[x.ndim() - 2];
  auto stride_1 = x.strides()[x.ndim() - 1];
  if (stride_0 == x.shape(-1) && stride_1 == 1) {
    return x;
  } else {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  }
}

} // namespace

void fast::AffineQuantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  if (dequantize_) {
    auto wq = ensure_row_contiguous(inputs[0], enc, s);
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto biases = ensure_row_contiguous(inputs[2], enc, s);
    auto& w = outputs[0];

    w.set_data(allocator::malloc(w.nbytes()));

    affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
  } else {
    auto w = ensure_row_contiguous(inputs[0], enc, s);
    auto& wq = outputs[0];
    auto& scales = outputs[1];
    auto& biases = outputs[2];

    wq.set_data(allocator::malloc(wq.nbytes()));
    scales.set_data(allocator::malloc(scales.nbytes()));
    biases.set_data(allocator::malloc(biases.nbytes()));

    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
  }
}

} // namespace mlx::core
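Note the two contiguity helpers differ in strictness: ensure_row_contiguous requires the whole array to be row contiguous, while ensure_row_contiguous_matrix only checks the two trailing (matrix) strides, so a batched array with an arbitrary leading stride avoids a copy. A plain-integer restatement of that stride test (illustrative; the function name is not from the commit):

#include <cstdint>
#include <vector>

// True when the trailing matrix of `shape` is packed: rows are shape.back()
// elements apart and elements within a row are adjacent.
bool matrix_is_row_contiguous(
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  size_t n = shape.size();
  return strides[n - 2] == shape[n - 1] && strides[n - 1] == 1;
}

// Example: matrix_is_row_contiguous({4, 8, 16}, {1024, 16, 1}) == true, since a
// large batch stride is fine, while a transposed matrix fails the test:
// matrix_is_row_contiguous({8, 16}, {1, 8}) == false.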
mlx/backend/cuda/quantized/quantized.h (new file, 27 lines)
@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"

namespace mlx::core {

void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

} // namespace mlx::core
mlx/backend/cuda/quantized/quantized_utils.cuh (new file, 59 lines)
@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.

namespace mlx::core {

namespace cu {

template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

} // namespace cu

template <typename F>
void dispatch_groups(int group_size, F&& f) {
  switch (group_size) {
    case 32:
      f(std::integral_constant<int, 32>{});
      break;
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

template <typename F>
void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 5:
      f(std::integral_constant<int, 5>{});
      break;
    case 6:
      f(std::integral_constant<int, 6>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

} // namespace mlx::core
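dispatch_groups and dispatch_bits follow the same pattern as dispatch_float_types: a runtime value is matched in a switch and passed to a generic lambda as a std::integral_constant, so the innermost code can use it as a template argument (as in cu::affine_quantize<T, group_size.value, bits.value> above). A standalone sketch of the pattern; print_config is illustrative and not part of the commit:

#include <cstdio>
#include <type_traits>

template <typename F>
void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 4: f(std::integral_constant<int, 4>{}); break;
    case 8: f(std::integral_constant<int, 8>{}); break;
    // (3-, 5-, and 6-bit cases elided for brevity)
  }
}

template <int group_size, int bits>
void print_config() {
  std::printf("group_size=%d bits=%d\n", group_size, bits);
}

int main() {
  int bits = 4; // runtime value, e.g. read from the primitive
  dispatch_bits(bits, [&](auto b) {
    // b.value is a compile-time constant here, usable as a template argument.
    print_config</*group_size=*/64, b.value>();
  });
  return 0;
}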