mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-04 01:36:42 +08:00
[CUDA] Quantized refactoring (#2442)
This commit is contained in:
parent
2204182bba
commit
3bf81ed1bd
@ -46,7 +46,8 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||
|
@ -2,30 +2,17 @@
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_pack_factor() {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_bytes_per_pack() {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
__global__ void
|
||||
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
||||
@ -240,140 +227,100 @@ __global__ void affine_dequantize(
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
namespace {
|
||||
|
||||
inline array ensure_row_contiguous(
|
||||
const array& x,
|
||||
void affine_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
array& biases,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename F>
|
||||
void dispatch_groups(int group_size, F&& f) {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
f(std::integral_constant<int, 32>{});
|
||||
break;
|
||||
case 64:
|
||||
f(std::integral_constant<int, 64>{});
|
||||
break;
|
||||
case 128:
|
||||
f(std::integral_constant<int, 128>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void dispatch_bits(int bits, F&& f) {
|
||||
switch (bits) {
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 3:
|
||||
f(std::integral_constant<int, 3>{});
|
||||
break;
|
||||
case 4:
|
||||
f(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
case 5:
|
||||
f(std::integral_constant<int, 5>{});
|
||||
break;
|
||||
case 6:
|
||||
f(std::integral_constant<int, 6>{});
|
||||
break;
|
||||
case 8:
|
||||
f(std::integral_constant<int, 8>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
auto w = ensure_row_contiguous(w_pre, enc, s);
|
||||
enc.set_input_array(w);
|
||||
if (dequantize_) {
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
enc.set_input_array(scales);
|
||||
enc.set_input_array(biases);
|
||||
enc.set_output_array(out);
|
||||
} else {
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
enc.set_output_array(out);
|
||||
enc.set_output_array(scales);
|
||||
enc.set_output_array(biases);
|
||||
}
|
||||
|
||||
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
|
||||
|
||||
// Treat uint32 as uint8 in kernel
|
||||
int uint8_per_uint32 = 4;
|
||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
||||
: bits_ == 6 ? 4
|
||||
: 8 / bits_;
|
||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
|
||||
size_t size =
|
||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
||||
// Calculate the number of elements per thread
|
||||
int per_thread = group_size_ / WARP_SIZE;
|
||||
size_t size = w.size() / per_thread;
|
||||
|
||||
// Calculate the thread grid that we need to launch
|
||||
bool large = size > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
|
||||
if (dequantize_) {
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
} else {
|
||||
grid_shape.back() /= per_thread;
|
||||
}
|
||||
|
||||
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
|
||||
enc.set_input_array(w);
|
||||
enc.set_output_array(wq);
|
||||
enc.set_output_array(scales);
|
||||
enc.set_output_array(biases);
|
||||
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if (dequantize_) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
cu::affine_dequantize<DataType, group_size.value, bits.value>,
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
w.data<uint8_t>(),
|
||||
inputs[1].data<DataType>(),
|
||||
inputs[2].data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
out.size());
|
||||
} else {
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
cu::affine_quantize<DataType, group_size.value, bits.value>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
w.data<DataType>(),
|
||||
out.data<uint8_t>(),
|
||||
outputs[1].data<DataType>(),
|
||||
outputs[2].data<DataType>(),
|
||||
w.data<T>(),
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
biases.data<T>(),
|
||||
w.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void affine_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& w,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
|
||||
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
|
||||
constexpr int uint8_per_uint32 = 4;
|
||||
int packs_per_int;
|
||||
switch (bits_) {
|
||||
case 3:
|
||||
case 5:
|
||||
packs_per_int = 8;
|
||||
break;
|
||||
case 6:
|
||||
packs_per_int = 4;
|
||||
break;
|
||||
default:
|
||||
packs_per_int = 8 / bits_;
|
||||
}
|
||||
|
||||
size_t size = w.size() / packs_per_int;
|
||||
bool large = size > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
|
||||
enc.set_input_array(wq);
|
||||
enc.set_input_array(scales);
|
||||
enc.set_input_array(biases);
|
||||
enc.set_output_array(w);
|
||||
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
biases.data<T>(),
|
||||
w.data<T>(),
|
||||
w.size());
|
||||
});
|
||||
});
|
||||
});
|
72
mlx/backend/cuda/quantized/quantized.cpp
Normal file
72
mlx/backend/cuda/quantized/quantized.cpp
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/quantized/quantized.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline array ensure_row_contiguous(
|
||||
const array& x,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
inline array ensure_row_contiguous_matrix(
|
||||
const array& x,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
if (dequantize_) {
|
||||
auto wq = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
auto& w = outputs[0];
|
||||
|
||||
w.set_data(allocator::malloc(w.nbytes()));
|
||||
|
||||
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
auto w = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto& wq = outputs[0];
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
|
||||
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
|
||||
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
27
mlx/backend/cuda/quantized/quantized.h
Normal file
27
mlx/backend/cuda/quantized/quantized.h
Normal file
@ -0,0 +1,27 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void affine_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
array& biases,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void affine_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& w,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_pack_factor() {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_bytes_per_pack() {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename F>
|
||||
void dispatch_groups(int group_size, F&& f) {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
f(std::integral_constant<int, 32>{});
|
||||
break;
|
||||
case 64:
|
||||
f(std::integral_constant<int, 64>{});
|
||||
break;
|
||||
case 128:
|
||||
f(std::integral_constant<int, 128>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void dispatch_bits(int bits, F&& f) {
|
||||
switch (bits) {
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 3:
|
||||
f(std::integral_constant<int, 3>{});
|
||||
break;
|
||||
case 4:
|
||||
f(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
case 5:
|
||||
f(std::integral_constant<int, 5>{});
|
||||
break;
|
||||
case 6:
|
||||
f(std::integral_constant<int, 6>{});
|
||||
break;
|
||||
case 8:
|
||||
f(std::integral_constant<int, 8>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user