Refactor quantized
@@ -22,7 +22,7 @@ project(

# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -42,7 +42,9 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -130,3 +132,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

# Make Thunderkittens available
FetchContent_Declare(
  kittens
  GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
  GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
  GIT_SHALLOW TRUE)
FetchContent_MakeAvailable(kittens)
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")
@@ -81,7 +81,6 @@ NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
@@ -2,30 +2,17 @@

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {
namespace cu {

namespace cg = cooperative_groups;

template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,144 +227,100 @@ __global__ void affine_dequantize(
}

} // namespace cu
namespace {

inline array ensure_row_contiguous(
    const array& x,
void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (!x.flags().row_contiguous) {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  } else {
    return x;
  }
}

} // namespace

template <typename F>
void dispatch_groups(int group_size, F&& f) {
  switch (group_size) {
    case 32:
      f(std::integral_constant<int, 32>{});
      break;
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

template <typename F>
void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 5:
      f(std::integral_constant<int, 5>{});
      break;
    case 6:
      f(std::integral_constant<int, 6>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

void fast::AffineQuantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& w_pre = inputs[0];
  auto& out = outputs[0];
  out.set_data(allocator::malloc(out.nbytes()));

  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  auto w = ensure_row_contiguous(w_pre, enc, s);
  enc.set_input_array(w);
  if (dequantize_) {
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto biases = ensure_row_contiguous(inputs[2], enc, s);
    enc.set_input_array(scales);
    enc.set_input_array(biases);
    enc.set_output_array(out);
  } else {
    auto& scales = outputs[1];
    auto& biases = outputs[2];
    scales.set_data(allocator::malloc(scales.nbytes()));
    biases.set_data(allocator::malloc(biases.nbytes()));
    enc.set_output_array(out);
    enc.set_output_array(scales);
    enc.set_output_array(biases);
  }

  auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();

  // Treat uint32 as uint8 in kernel
  int uint8_per_uint32 = 4;
  int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
      : bits_ == 6                               ? 4
                                                 : 8 / bits_;
  int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
  size_t size =
      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
  // Calculate the number of elements per thread
  int per_thread = group_size_ / WARP_SIZE;
  size_t size = w.size() / per_thread;

  // Calculate the thread grid that we need to launch
  bool large = size > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() /= per_thread;

  if (dequantize_) {
    grid_shape.back() *= uint8_per_uint32;
  } else {
    grid_shape.back() /= per_thread;
  }

  dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
  enc.set_input_array(w);
  enc.set_output_array(wq);
  enc.set_output_array(scales);
  enc.set_output_array(biases);
  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        if (dequantize_) {
          auto kernel =
              cu::affine_dequantize<DataType, group_size.value, bits.value>;
          auto [num_blocks, block_dims] =
              get_launch_args(kernel, size, grid_shape, w.strides(), large);
          enc.add_kernel_node(
              kernel,
              num_blocks,
              block_dims,
              w.data<uint8_t>(),
              inputs[1].data<DataType>(),
              inputs[2].data<DataType>(),
              out.data<DataType>(),
              out.size());
        } else {
          auto kernel =
              cu::affine_quantize<DataType, group_size.value, bits.value>;
          auto [num_blocks, block_dims] =
              get_launch_args(kernel, size, grid_shape, w.strides(), large);
          enc.add_kernel_node(
              kernel,
              num_blocks,
              block_dims,
              w.data<DataType>(),
              out.data<uint8_t>(),
              outputs[1].data<DataType>(),
              outputs[2].data<DataType>(),
              w.size());
        }
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
        auto [num_blocks, block_dims] =
            get_launch_args(kernel, size, grid_shape, w.strides(), large);
        enc.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            w.data<T>(),
            wq.data<uint8_t>(),
            scales.data<T>(),
            biases.data<T>(),
            w.size());
      });
    });
  });
}

void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s) {
  // Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
  // one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
  constexpr int uint8_per_uint32 = 4;
  int packs_per_int;
  switch (bits_) {
    case 3:
    case 5:
      packs_per_int = 8;
      break;
    case 6:
      packs_per_int = 4;
      break;
    default:
      packs_per_int = 8 / bits_;
  }

  size_t size = w.size() / packs_per_int;
  bool large = size > UINT_MAX;
  auto grid_shape = w.shape();
  grid_shape.back() *= uint8_per_uint32;

  enc.set_input_array(wq);
  enc.set_input_array(scales);
  enc.set_input_array(biases);
  enc.set_output_array(w);
  dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
        auto [num_blocks, block_dims] =
            get_launch_args(kernel, size, grid_shape, w.strides(), large);
        enc.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            wq.data<uint8_t>(),
            scales.data<T>(),
            biases.data<T>(),
            w.data<T>(),
            w.size());
      });
    });
  });
}
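For concreteness, the launch-size arithmetic in the new affine_quantize launcher above works out as follows. This is a host-side sketch with assumed values (WARP_SIZE taken as 32, an example 4096x4096 weight, group_size 64); it is an illustration, not part of the commit.

// Host-side sketch of the launch-size arithmetic in affine_quantize above.
// Assumptions (not from the commit): WARP_SIZE == 32, a weight of shape
// (4096, 4096), group_size == 64.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  constexpr int WARP_SIZE = 32;              // assumed warp width
  const int group_size = 64;                 // one quantization group
  std::vector<size_t> shape = {4096, 4096};  // example weight shape
  size_t w_size = shape[0] * shape[1];

  // Each thread handles group_size / WARP_SIZE consecutive elements, so one
  // warp spans exactly one quantization group.
  int per_thread = group_size / WARP_SIZE;  // 2
  size_t size = w_size / per_thread;        // 8'388'608 threads

  // The last grid dimension shrinks by the same factor.
  std::vector<size_t> grid_shape = shape;
  grid_shape.back() /= per_thread;  // {4096, 2048}

  assert(per_thread == 2);
  assert(size == 8388608u);
  assert(grid_shape.back() == 2048);
  return 0;
}

With a 32-wide warp and group_size 64, each thread handles two elements, so a warp covers one quantization group.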
mlx/backend/cuda/quantized/qmm.cu (new file, 37 lines)
@@ -0,0 +1,37 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"

namespace mlx::core {

namespace cu {} // namespace cu

void qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    bool transpose_,
    int group_size_,
    int bits_,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc,
    const Stream& s) {
  dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
    dispatch_groups(group_size_, [&](auto group_size) {
      dispatch_bits(bits_, [&](auto bits) {
        dispatch_bool(transpose_, [&](auto transpose) {
          using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
        });
      });
    });
  });
}

} // namespace mlx::core
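The qmm stub above also shows the dispatch idiom these files rely on: runtime parameters (dtype, group size, bit width, transpose flag) become compile-time template arguments by nesting dispatchers that invoke a generic lambda with std::integral_constant tags. Below is a minimal, self-contained sketch of that idiom; the dispatch helpers are re-implemented here for illustration (dispatch_bool and dispatch_float_types come from elsewhere in MLX and are not shown in this diff), and fake_qmm is a stand-in, not an MLX function.

// Minimal illustration of the integral_constant dispatch pattern used in
// qmm and the quantize/dequantize launchers. Not MLX code; names mirror the
// helpers in quantized_utils.cuh for clarity.
#include <cstdio>
#include <type_traits>

template <typename F>
void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 4: f(std::integral_constant<int, 4>{}); break;
    case 8: f(std::integral_constant<int, 8>{}); break;
  }
}

template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

// A "kernel" whose behavior is fixed at compile time.
template <int bits, bool transpose>
void fake_qmm() {
  std::printf("qmm<bits=%d, transpose=%d>\n", bits, int(transpose));
}

int main() {
  int runtime_bits = 4;  // e.g. read from the primitive at runtime
  bool runtime_transpose = true;
  dispatch_bits(runtime_bits, [&](auto bits) {
    dispatch_bool(runtime_transpose, [&](auto transpose) {
      // bits.value and transpose.value are constant expressions here, so
      // they can be used as template arguments.
      fake_qmm<bits.value, transpose.value>();
    });
  });
  return 0;
}

The payoff is that group_size.value, bits.value, and transpose.value are usable as template arguments inside the lambda, so a single runtime entry point can instantiate a fully specialized kernel.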
mlx/backend/cuda/quantized/quantized.cu (new file, 113 lines)
@@ -0,0 +1,113 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace {

inline array ensure_row_contiguous(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  if (!x.flags().row_contiguous) {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  } else {
    return x;
  }
}

inline array ensure_row_contiguous_matrix(
    const array& x,
    cu::CommandEncoder& enc,
    const Stream& s) {
  auto stride_0 = x.strides()[x.ndim() - 2];
  auto stride_1 = x.strides()[x.ndim() - 1];
  if (stride_0 == x.shape(-1) && stride_1 == 1) {
    return x;
  } else {
    array x_copy = contiguous_copy_gpu(x, s);
    enc.add_temporary(x_copy);
    return x_copy;
  }
}

} // namespace

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  out.set_data(allocator::malloc(out.nbytes()));

  // Make sure the last two dims of x and w, s, b are contiguous. This should
  // be relaxed for x.
  array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
  array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
  array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
  array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);

  // Extract the matmul shapes
  bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
  int K = x.shape(-1);
  int M = non_batched ? x.size() / K : x.shape(-2);
  int N = out.shape(-1);

  qmm(x,
      w,
      scales,
      biases,
      out,
      transpose_,
      group_size_,
      bits_,
      M,
      N,
      K,
      enc,
      s);
}

void fast::AffineQuantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  if (dequantize_) {
    auto wq = ensure_row_contiguous(inputs[0], enc, s);
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto biases = ensure_row_contiguous(inputs[2], enc, s);
    auto& w = outputs[0];

    w.set_data(allocator::malloc(w.nbytes()));

    affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
  } else {
    auto w = ensure_row_contiguous(inputs[0], enc, s);
    auto& wq = outputs[0];
    auto& scales = outputs[1];
    auto& biases = outputs[2];

    wq.set_data(allocator::malloc(wq.nbytes()));
    scales.set_data(allocator::malloc(scales.nbytes()));
    biases.set_data(allocator::malloc(biases.nbytes()));

    affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
  }
}

} // namespace mlx::core
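To illustrate the shape extraction in QuantizedMatmul::eval_gpu above: when the weight is 2-D and x is row-contiguous, every leading dimension of x folds into M. The sketch below uses assumed example shapes and mirrors only that arithmetic; it is not MLX code.

// Sketch of the M/N/K extraction in QuantizedMatmul::eval_gpu for an assumed
// activation of shape (2, 16, 512) multiplying a 2-D quantized weight.
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> x_shape = {2, 16, 512};  // batch, sequence, features
  std::vector<int> out_shape = {2, 16, 1024};
  bool w_is_2d = true;
  bool x_row_contiguous = true;

  bool non_batched = w_is_2d && x_row_contiguous;
  int K = x_shape.back();  // 512
  long x_size = std::accumulate(
      x_shape.begin(), x_shape.end(), 1L, std::multiplies<long>());
  int M = non_batched ? int(x_size / K) : x_shape[x_shape.size() - 2];  // 32
  int N = out_shape.back();  // 1024

  // With non_batched == true the 2x16 leading dims fold into a single
  // 32-row matmul against the quantized weight.
  assert(K == 512 && M == 32 && N == 1024);
  return 0;
}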
mlx/backend/cuda/quantized/quantized.cuh (new file, 42 lines)
@@ -0,0 +1,42 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"

namespace mlx::core {

void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

void qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    array& out,
    bool transpose_,
    int group_size_,
    int bits_,
    int M,
    int N,
    int K,
    cu::CommandEncoder& enc,
    const Stream& s);

} // namespace mlx::core
mlx/backend/cuda/quantized/quantized_utils.cuh (new file, 59 lines)
@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.

namespace mlx::core {

namespace cu {

template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

} // namespace cu

template <typename F>
void dispatch_groups(int group_size, F&& f) {
  switch (group_size) {
    case 32:
      f(std::integral_constant<int, 32>{});
      break;
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

template <typename F>
void dispatch_bits(int bits, F&& f) {
  switch (bits) {
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 5:
      f(std::integral_constant<int, 5>{});
      break;
    case 6:
      f(std::integral_constant<int, 6>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

} // namespace mlx::core
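The two helpers in quantized_utils.cuh above fix how many values share one pack and how many bytes that pack occupies for each supported bit width. The following host-side mirror (re-implemented as plain constexpr functions, since the originals are __device__; not part of the commit) spells out the resulting values.

// Host-side mirror of get_pack_factor / get_bytes_per_pack from
// quantized_utils.cuh, with the per-bit-width results checked explicitly.
constexpr short pack_factor(int bits, int wsize = 8) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

constexpr short bytes_per_pack(int bits, int wsize = 8) {
  bool power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

// 2, 4, 8 bits: 4/2/1 values per pack, 1 byte per pack.
static_assert(pack_factor(2) == 4 && bytes_per_pack(2) == 1);
static_assert(pack_factor(4) == 2 && bytes_per_pack(4) == 1);
static_assert(pack_factor(8) == 1 && bytes_per_pack(8) == 1);
// 3 and 6 bits: 8 or 4 values packed into 3 bytes (24 bits).
static_assert(pack_factor(3) == 8 && bytes_per_pack(3) == 3);
static_assert(pack_factor(6) == 4 && bytes_per_pack(6) == 3);
// 5 bits: 8 values packed into 5 bytes (40 bits).
static_assert(pack_factor(5) == 8 && bytes_per_pack(5) == 5);

So 2-, 4-, and 8-bit values pack into single bytes, 3- and 6-bit values pack 8 and 4 values respectively into 3 bytes, and 5-bit packs 8 values into 5 bytes, matching the comment in affine_dequantize above.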