mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-04 01:36:42 +08:00
[CUDA] Quantized refactoring (#2442)
This commit is contained in:
parent
2204182bba
commit
3bf81ed1bd
@ -46,7 +46,8 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
|
|
||||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||||
|
@ -2,30 +2,17 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/fast_primitives.h"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <int bits, int wsize = 8>
|
|
||||||
inline constexpr __device__ short get_pack_factor() {
|
|
||||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int bits, int wsize = 8>
|
|
||||||
inline constexpr __device__ short get_bytes_per_pack() {
|
|
||||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
||||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
__global__ void
|
__global__ void
|
||||||
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
||||||
@ -240,140 +227,100 @@ __global__ void affine_dequantize(
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
namespace {
|
|
||||||
|
|
||||||
inline array ensure_row_contiguous(
|
void affine_quantize(
|
||||||
const array& x,
|
const array& w,
|
||||||
|
array& wq,
|
||||||
|
array& scales,
|
||||||
|
array& biases,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
cu::CommandEncoder& enc,
|
cu::CommandEncoder& enc,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
// Calculate the number of elements per thread
|
||||||
array x_copy = contiguous_copy_gpu(x, s);
|
int per_thread = group_size_ / WARP_SIZE;
|
||||||
enc.add_temporary(x_copy);
|
size_t size = w.size() / per_thread;
|
||||||
return x_copy;
|
|
||||||
} else {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void dispatch_groups(int group_size, F&& f) {
|
|
||||||
switch (group_size) {
|
|
||||||
case 32:
|
|
||||||
f(std::integral_constant<int, 32>{});
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
f(std::integral_constant<int, 64>{});
|
|
||||||
break;
|
|
||||||
case 128:
|
|
||||||
f(std::integral_constant<int, 128>{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void dispatch_bits(int bits, F&& f) {
|
|
||||||
switch (bits) {
|
|
||||||
case 2:
|
|
||||||
f(std::integral_constant<int, 2>{});
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
f(std::integral_constant<int, 3>{});
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
f(std::integral_constant<int, 4>{});
|
|
||||||
break;
|
|
||||||
case 5:
|
|
||||||
f(std::integral_constant<int, 5>{});
|
|
||||||
break;
|
|
||||||
case 6:
|
|
||||||
f(std::integral_constant<int, 6>{});
|
|
||||||
break;
|
|
||||||
case 8:
|
|
||||||
f(std::integral_constant<int, 8>{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void fast::AffineQuantize::eval_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs) {
|
|
||||||
auto& w_pre = inputs[0];
|
|
||||||
auto& out = outputs[0];
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
|
|
||||||
auto& s = stream();
|
|
||||||
auto& d = cu::device(s.device);
|
|
||||||
auto& enc = d.get_command_encoder(s);
|
|
||||||
|
|
||||||
auto w = ensure_row_contiguous(w_pre, enc, s);
|
|
||||||
enc.set_input_array(w);
|
|
||||||
if (dequantize_) {
|
|
||||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
|
||||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
|
||||||
enc.set_input_array(scales);
|
|
||||||
enc.set_input_array(biases);
|
|
||||||
enc.set_output_array(out);
|
|
||||||
} else {
|
|
||||||
auto& scales = outputs[1];
|
|
||||||
auto& biases = outputs[2];
|
|
||||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
|
||||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
|
||||||
enc.set_output_array(out);
|
|
||||||
enc.set_output_array(scales);
|
|
||||||
enc.set_output_array(biases);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
|
|
||||||
|
|
||||||
// Treat uint32 as uint8 in kernel
|
|
||||||
int uint8_per_uint32 = 4;
|
|
||||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
|
||||||
: bits_ == 6 ? 4
|
|
||||||
: 8 / bits_;
|
|
||||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
|
|
||||||
size_t size =
|
|
||||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
|
||||||
|
|
||||||
|
// Calculate the thread grid that we need to launch
|
||||||
bool large = size > UINT_MAX;
|
bool large = size > UINT_MAX;
|
||||||
auto grid_shape = w.shape();
|
auto grid_shape = w.shape();
|
||||||
|
grid_shape.back() /= per_thread;
|
||||||
|
|
||||||
if (dequantize_) {
|
enc.set_input_array(w);
|
||||||
grid_shape.back() *= uint8_per_uint32;
|
enc.set_output_array(wq);
|
||||||
} else {
|
enc.set_output_array(scales);
|
||||||
grid_shape.back() /= per_thread;
|
enc.set_output_array(biases);
|
||||||
}
|
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||||
|
|
||||||
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
|
|
||||||
dispatch_groups(group_size_, [&](auto group_size) {
|
dispatch_groups(group_size_, [&](auto group_size) {
|
||||||
dispatch_bits(bits_, [&](auto bits) {
|
dispatch_bits(bits_, [&](auto bits) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
if (dequantize_) {
|
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] =
|
||||||
get_launch_args(size, grid_shape, w.strides(), large);
|
get_launch_args(size, grid_shape, w.strides(), large);
|
||||||
enc.add_kernel_node(
|
enc.add_kernel_node(
|
||||||
cu::affine_dequantize<DataType, group_size.value, bits.value>,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
w.data<uint8_t>(),
|
w.data<T>(),
|
||||||
inputs[1].data<DataType>(),
|
wq.data<uint8_t>(),
|
||||||
inputs[2].data<DataType>(),
|
scales.data<T>(),
|
||||||
out.data<DataType>(),
|
biases.data<T>(),
|
||||||
out.size());
|
w.size());
|
||||||
} else {
|
});
|
||||||
auto [num_blocks, block_dims] =
|
});
|
||||||
get_launch_args(size, grid_shape, w.strides(), large);
|
});
|
||||||
enc.add_kernel_node(
|
}
|
||||||
cu::affine_quantize<DataType, group_size.value, bits.value>,
|
|
||||||
num_blocks,
|
void affine_dequantize(
|
||||||
block_dims,
|
const array& wq,
|
||||||
w.data<DataType>(),
|
const array& scales,
|
||||||
out.data<uint8_t>(),
|
const array& biases,
|
||||||
outputs[1].data<DataType>(),
|
array& w,
|
||||||
outputs[2].data<DataType>(),
|
int group_size_,
|
||||||
w.size());
|
int bits_,
|
||||||
}
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
|
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
|
||||||
|
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
|
||||||
|
constexpr int uint8_per_uint32 = 4;
|
||||||
|
int packs_per_int;
|
||||||
|
switch (bits_) {
|
||||||
|
case 3:
|
||||||
|
case 5:
|
||||||
|
packs_per_int = 8;
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
packs_per_int = 4;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
packs_per_int = 8 / bits_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size = w.size() / packs_per_int;
|
||||||
|
bool large = size > UINT_MAX;
|
||||||
|
auto grid_shape = w.shape();
|
||||||
|
grid_shape.back() *= uint8_per_uint32;
|
||||||
|
|
||||||
|
enc.set_input_array(wq);
|
||||||
|
enc.set_input_array(scales);
|
||||||
|
enc.set_input_array(biases);
|
||||||
|
enc.set_output_array(w);
|
||||||
|
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||||
|
dispatch_groups(group_size_, [&](auto group_size) {
|
||||||
|
dispatch_bits(bits_, [&](auto bits) {
|
||||||
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(size, grid_shape, w.strides(), large);
|
||||||
|
enc.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
wq.data<uint8_t>(),
|
||||||
|
scales.data<T>(),
|
||||||
|
biases.data<T>(),
|
||||||
|
w.data<T>(),
|
||||||
|
w.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
72
mlx/backend/cuda/quantized/quantized.cpp
Normal file
72
mlx/backend/cuda/quantized/quantized.cpp
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/quantized/quantized.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
inline array ensure_row_contiguous(
|
||||||
|
const array& x,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
|
if (!x.flags().row_contiguous) {
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
|
enc.add_temporary(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
} else {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array ensure_row_contiguous_matrix(
|
||||||
|
const array& x,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
|
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||||
|
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||||
|
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
|
enc.add_temporary(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void fast::AffineQuantize::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = cu::device(s.device);
|
||||||
|
auto& enc = d.get_command_encoder(s);
|
||||||
|
|
||||||
|
if (dequantize_) {
|
||||||
|
auto wq = ensure_row_contiguous(inputs[0], enc, s);
|
||||||
|
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||||
|
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||||
|
auto& w = outputs[0];
|
||||||
|
|
||||||
|
w.set_data(allocator::malloc(w.nbytes()));
|
||||||
|
|
||||||
|
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||||
|
} else {
|
||||||
|
auto w = ensure_row_contiguous(inputs[0], enc, s);
|
||||||
|
auto& wq = outputs[0];
|
||||||
|
auto& scales = outputs[1];
|
||||||
|
auto& biases = outputs[2];
|
||||||
|
|
||||||
|
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||||
|
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||||
|
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||||
|
|
||||||
|
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
27
mlx/backend/cuda/quantized/quantized.h
Normal file
27
mlx/backend/cuda/quantized/quantized.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void affine_quantize(
|
||||||
|
const array& w,
|
||||||
|
array& wq,
|
||||||
|
array& scales,
|
||||||
|
array& biases,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
void affine_dequantize(
|
||||||
|
const array& wq,
|
||||||
|
const array& scales,
|
||||||
|
const array& biases,
|
||||||
|
array& w,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
template <int bits, int wsize = 8>
|
||||||
|
inline constexpr __device__ short get_pack_factor() {
|
||||||
|
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int bits, int wsize = 8>
|
||||||
|
inline constexpr __device__ short get_bytes_per_pack() {
|
||||||
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
|
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_groups(int group_size, F&& f) {
|
||||||
|
switch (group_size) {
|
||||||
|
case 32:
|
||||||
|
f(std::integral_constant<int, 32>{});
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
f(std::integral_constant<int, 64>{});
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
f(std::integral_constant<int, 128>{});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_bits(int bits, F&& f) {
|
||||||
|
switch (bits) {
|
||||||
|
case 2:
|
||||||
|
f(std::integral_constant<int, 2>{});
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
f(std::integral_constant<int, 3>{});
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
f(std::integral_constant<int, 4>{});
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
f(std::integral_constant<int, 5>{});
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
f(std::integral_constant<int, 6>{});
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
f(std::integral_constant<int, 8>{});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user