Compare commits

...

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
6c60bd1cbf Fixed mma and working dequant 2025-07-21 04:47:42 -07:00
Angelos Katharopoulos
a64cc02a0c Somewhat working matmul primitives 2025-07-21 04:47:42 -07:00
Angelos Katharopoulos
346ae5fdb5 Refactor quantized 2025-07-21 04:47:41 -07:00
8 changed files with 794 additions and 147 deletions

View File

@@ -22,7 +22,7 @@ project(
# ----------------------------- Setup ----------------------------- # ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER) set(CMAKE_INSTALL_MESSAGE NEVER)

View File

@@ -42,7 +42,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -90,7 +92,7 @@ target_compile_options(
# Compute capability 7 is required for synchronization between CPU/GPU with # Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain. # managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES set(MLX_CUDA_ARCHITECTURES
"70;80" "80"
CACHE STRING "CUDA architectures") CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -130,3 +132,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT. # Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# Make Thunderkittens available
FetchContent_Declare(
kittens
GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
GIT_SHALLOW TRUE)
FetchContent_MakeAvailable(kittens)
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")

View File

@@ -81,7 +81,6 @@ NO_GPU(Hadamard)
NO_GPU(Load) NO_GPU(Load)
NO_GPU_MULTI(LUF) NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF) NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM) NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD) NO_GPU_MULTI(SVD)
NO_GPU(Inverse) NO_GPU(Inverse)

View File

@@ -2,30 +2,17 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
__global__ void __global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) { affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,144 +227,100 @@ __global__ void affine_dequantize(
} }
} // namespace cu } // namespace cu
namespace {
inline array ensure_row_contiguous( void affine_quantize(
const array& x, const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc, cu::CommandEncoder& enc,
const Stream& s) { const Stream& s) {
if (!x.flags().row_contiguous) { // Calculate the number of elements per thread
array x_copy = contiguous_copy_gpu(x, s); int per_thread = group_size_ / WARP_SIZE;
enc.add_temporary(x_copy); size_t size = w.size() / per_thread;
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX; bool large = size > UINT_MAX;
auto grid_shape = w.shape(); auto grid_shape = w.shape();
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread; grid_shape.back() /= per_thread;
}
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) { enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
enc.set_output_array(biases);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) { dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) { dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) { auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto kernel =
cu::affine_dequantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large); get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node( enc.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,
block_dims, block_dims,
w.data<uint8_t>(), w.data<T>(),
inputs[1].data<DataType>(), wq.data<uint8_t>(),
inputs[2].data<DataType>(), scales.data<T>(),
out.data<DataType>(), biases.data<T>(),
out.size());
} else {
auto kernel =
cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<DataType>(),
out.data<uint8_t>(),
outputs[1].data<DataType>(),
outputs[2].data<DataType>(),
w.size()); w.size());
});
});
});
} }
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
}); });
}); });
}); });

View File

@@ -0,0 +1,480 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
namespace mlx::core {
namespace cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Tile16x16 {
using T2 = Vector2_t<T>;
T2 values[4];
__device__ inline void clear() {
for (int i = 0; i < 4; i++) {
values[i] = static_cast<T2>(0);
}
}
__device__ inline void load(uint32_t src_address) {
if constexpr (
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(*(uint32_t*)&(values[0])),
"=r"(*(uint32_t*)&(values[1])),
"=r"(*(uint32_t*)&(values[2])),
"=r"(*(uint32_t*)&(values[3]))
: "r"(src_address));
}
}
__device__ inline void store(uint32_t dst_address) {
if constexpr (
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
asm volatile(
"stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(*(uint32_t*)&(values[0])),
"=r"(*(uint32_t*)&(values[1])),
"=r"(*(uint32_t*)&(values[2])),
"=r"(*(uint32_t*)&(values[3]))
: "r"(dst_address));
} else {
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
const uint32_t a = dst_address + ((row + 0) * 8 + col + 0) * sizeof(T2);
const uint32_t b = dst_address + ((row + 0) * 8 + col + 4) * sizeof(T2);
const uint32_t c = dst_address + ((row + 8) * 8 + col + 0) * sizeof(T2);
const uint32_t d = dst_address + ((row + 8) * 8 + col + 4) * sizeof(T2);
if constexpr (sizeof(T2) == 4) {
asm volatile("st.shared.b32 [%1], %0;\n"
:
: "r"(*(uint32_t*)&(values[0])), "r"(a));
asm volatile("st.shared.b32 [%1], %0;\n"
:
: "r"(*(uint32_t*)&(values[2])), "r"(b));
asm volatile("st.shared.b32 [%1], %0;\n"
:
: "r"(*(uint32_t*)&(values[1])), "r"(c));
asm volatile("st.shared.b32 [%1], %0;\n"
:
: "r"(*(uint32_t*)&(values[3])), "r"(d));
} else if constexpr (sizeof(T2) == 8) {
asm volatile("st.shared.b64 [%1], %0;\n"
:
: "r"(*(uint64_t*)&(values[0])), "r"(a));
asm volatile("st.shared.b64 [%1], %0;\n"
:
: "r"(*(uint64_t*)&(values[2])), "r"(b));
asm volatile("st.shared.b64 [%1], %0;\n"
:
: "r"(*(uint64_t*)&(values[1])), "r"(c));
asm volatile("st.shared.b64 [%1], %0;\n"
:
: "r"(*(uint64_t*)&(values[3])), "r"(d));
} else if constexpr (sizeof(T2) == 16) {
asm volatile("st.shared.b128 [%1], %0;\n"
:
: "r"(*(__int128*)&(values[0])), "r"(a));
asm volatile("st.shared.b128 [%1], %0;\n"
:
: "r"(*(__int128*)&(values[2])), "r"(b));
asm volatile("st.shared.b128 [%1], %0;\n"
:
: "r"(*(__int128*)&(values[1])), "r"(c));
asm volatile("st.shared.b128 [%1], %0;\n"
:
: "r"(*(__int128*)&(values[3])), "r"(d));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N) {
using U2 = Vector2_t<U>;
U2* x2 = reinterpret_cast<U2*>(x);
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if constexpr (std::is_same_v<U2, T2>) {
x2[(row + 0) * (N / 2) + col + 0] = values[0];
x2[(row + 0) * (N / 2) + col + 4] = values[2];
x2[(row + 8) * (N / 2) + col + 0] = values[1];
x2[(row + 8) * (N / 2) + col + 4] = values[3];
} else if constexpr (
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
x2[(row + 0) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[0].x, values[0].y);
x2[(row + 0) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[2].x, values[2].y);
x2[(row + 8) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[1].x, values[1].y);
x2[(row + 8) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[3].x, values[3].y);
}
}
};
template <typename T, int ROWS, int COLS>
struct __align__(16) SharedTile {
static constexpr int TILES_R = ROWS / 16;
static constexpr int TILES_C = COLS / 16;
static constexpr int NUM_ELEMENTS = ROWS * COLS;
static constexpr int swizzle_bytes =
(sizeof(T) == 2 ? (TILES_C % 4 == 0 ? 128 : (TILES_C % 2 == 0 ? 64 : 32))
: (sizeof(T) == 4 ? (TILES_C % 2 == 0 ? 128 : 64) : 0));
T data[ROWS * COLS];
__device__ static inline T* idx(T* ptr, int2 coord) {
if constexpr (swizzle_bytes > 0) {
int r = coord.x, c = coord.y;
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = c / subtile_cols;
const uint64_t addr =
(uint64_t)(&ptr
[outer_idx * ROWS * subtile_cols + r * subtile_cols +
c % subtile_cols]);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (T*)(addr ^ swizzle);
} else {
return ptr + coord.y * COLS + coord.x;
}
}
__device__ static inline uint32_t idx(uint32_t ptr, int2 coord) {
if constexpr (swizzle_bytes > 0) {
int r = coord.x, c = coord.y;
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = c / subtile_cols;
const uint32_t addr = ptr +
sizeof(T) *
(outer_idx * ROWS * subtile_cols + r * subtile_cols +
c % subtile_cols);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (addr ^ swizzle);
} else {
return ptr + sizeof(T) * (coord.y * COLS + coord.x);
}
}
__device__ inline T& operator[](int2 coord) {
return *idx(&data[0], coord);
}
__device__ inline void store(float4& v, int2 coord) {
*(reinterpret_cast<float4*>(idx(data, coord))) = v;
}
__device__ inline void store(float2& v, int2 coord) {
*(reinterpret_cast<float2*>(idx(data, coord))) = v;
}
__device__ inline void store(float& v, int2 coord) {
*(reinterpret_cast<float*>(idx(data, coord))) = v;
}
template <int N>
__device__ inline void store(T (&v)[N], int2 coord) {
if constexpr (sizeof(T) * N == 4) {
store(*(reinterpret_cast<float*>(&v[0])), coord);
} else if constexpr (sizeof(T) * N == 8) {
store(*(reinterpret_cast<float2*>(&v[0])), coord);
} else if constexpr (sizeof(T) * N == 16) {
store(*(reinterpret_cast<float4*>(&v[0])), coord);
} else {
#pragma unroll
for (int i = 0; i < N; i++) {
*idx(data, {coord.x, coord.y + i}) = v[i];
}
}
}
template <int NUM_WARPS>
__device__ inline void load(const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
#pragma unroll
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
float4 tmp;
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
store(tmp, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD});
}
}
template <int NUM_WARPS, int group_size, int bits>
__device__ inline void
load_quantized(const uint8_t* x, const T* scales, const T* biases, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD =
sizeof(uint32_t) * get_pack_factor<bits>();
constexpr int NUM_LOADS = NUM_ELEMENTS / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
constexpr int MASK = (1 << bits) - 1;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
const int Nx = N / get_pack_factor<bits>();
const int Ng = N / group_size;
x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor<bits>());
scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
#pragma unroll
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
T vs[ELEMENTS_PER_LOAD];
uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
T s = scales[i * STEP_ROWS * Ng];
T b = biases[i * STEP_ROWS * Ng];
#pragma unroll
for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
}
store(vs, {row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD});
}
}
};
template <typename TileAccum, typename Tile>
__device__ inline void mma(TileAccum& C, Tile& A, Tile& B) {}
__device__ inline void mma(
Tile16x16<float>& C,
Tile16x16<__nv_bfloat16>& A,
Tile16x16<__nv_bfloat16>& B) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[0].x),
"+f"(C.values[0].y),
"+f"(C.values[1].x),
"+f"(C.values[1].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[0])),
"r"(*(uint32_t*)(&B.values[2])),
// C matrix
"f"(C.values[0].x),
"f"(C.values[0].y),
"f"(C.values[1].x),
"f"(C.values[1].y));
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[2].x),
"+f"(C.values[2].y),
"+f"(C.values[3].x),
"+f"(C.values[3].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[1])),
"r"(*(uint32_t*)(&B.values[3])),
// C matrix
"f"(C.values[2].x),
"f"(C.values[2].y),
"f"(C.values[3].x),
"f"(C.values[3].y));
}
template <typename T, int BM, int BN, int BK, int group_size, int bits>
__global__ void qmm(
const T* x,
const uint8_t* w,
const T* scales,
const T* biases,
T* y,
int M,
int N,
int K) {
constexpr int NUM_WARPS = 4;
constexpr int WARP_M = (BM / 16) / (NUM_WARPS / 2);
constexpr int WARP_N = (BN / 16) / (NUM_WARPS / 2);
constexpr int WARP_K = BK / 16;
constexpr int WARP_STEP_M = WARP_M * 16;
constexpr int WARP_STEP_N = WARP_N * 16;
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int offset_m = (warpid / 2) * WARP_STEP_M;
const int offset_n = (warpid % 2) * WARP_STEP_N;
__shared__ SharedTile<T, BM, BK> xs;
__shared__ SharedTile<T, BN, BK> ws;
Tile16x16<float> C[WARP_M * WARP_N];
Tile16x16<T> A[WARP_M];
Tile16x16<T> B[WARP_N];
x += blockIdx.y * BM * K;
w += blockIdx.x * BN * K / get_pack_factor<bits>();
scales += blockIdx.x * BN * K / group_size;
biases += blockIdx.x * BN * K / group_size;
y += blockIdx.y * BM * N + blockIdx.x * BN;
#pragma unroll
for (int i = 0; i < WARP_M * WARP_N; i++) {
C[i].clear();
}
uint32_t base_addr_xs = __cvta_generic_to_shared(&xs.data[0]);
uint32_t base_addr_ws = __cvta_generic_to_shared(&ws.data[0]);
for (int k_block = 0; k_block < K; k_block += BK) {
xs.load<NUM_WARPS>(x + k_block, K);
ws.load_quantized<NUM_WARPS, group_size, bits>(
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
__syncthreads();
#pragma unroll
for (int k = 0; k < WARP_K; k++) {
#pragma unroll
for (int i = 0; i < WARP_M; i++) {
A[i].load(xs.idx(
base_addr_xs,
{offset_m + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
}
#pragma unroll
for (int i = 0; i < WARP_N; i++) {
B[i].load(ws.idx(
base_addr_ws,
{offset_n + i * 16 + laneid % 16, k * 16 + laneid / 16 * 8}));
}
#pragma unroll
for (int i = 0; i < WARP_M; i++) {
#pragma unroll
for (int j = 0; j < WARP_N; j++) {
mma(C[i * WARP_N + j], A[i], B[j]);
}
}
}
__syncthreads();
}
#pragma unroll
for (int i = 0; i < WARP_M; i++) {
#pragma unroll
for (int j = 0; j < WARP_N; j++) {
C[i * WARP_N + j].store_global(
y + (offset_m + i * 16) * N + offset_n + j * 16, N);
}
}
}
} // namespace cu
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s) {
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
// dispatch_groups(group_size_, [&](auto group_size) {
// dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 64;
constexpr int BN = 64;
constexpr int BK = 32;
auto kernel = cu::qmm<DataType, BM, BN, BK, 64, 4>;
dim3 grid(N / BN, M / BM);
enc.add_kernel_node(
kernel,
grid,
128,
x.data<DataType>(),
w.data<uint8_t>(),
scales.data<DataType>(),
biases.data<DataType>(),
out.data<DataType>(),
M,
N,
K);
//});
//});
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,113 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
}
} // namespace
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
out.set_data(allocator::malloc(out.nbytes()));
// Make sure the last two dims of x and w, s, b are contiguous. This should
// be relaxed for x.
array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
int K = x.shape(-1);
int M = non_batched ? x.size() / K : x.shape(-2);
int N = out.shape(-1);
qmm(x,
w,
scales,
biases,
out,
transpose_,
group_size_,
bits_,
M,
N,
K,
enc,
s);
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,42 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core