Compare commits

...

3 Commits

Author SHA1 Message Date
Awni Hannun
d32519c8ee fix gemv regression (#2445) 2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249 fix circular reference (#2443) 2025-07-30 09:37:44 -07:00
Angelos Katharopoulos
3bf81ed1bd [CUDA] Quantized refactoring (#2442) 2025-07-30 08:27:20 -07:00
10 changed files with 330 additions and 184 deletions

View File

@@ -46,7 +46,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)

View File

@@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
};
template <int N, typename T>
inline __device__ bool is_aligned(T* x) {
inline __host__ __device__ bool is_aligned(T* x) {
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
@@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
}
}
template <int N, typename T>
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {

View File

@@ -27,8 +27,9 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
auto local_mat =
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
@@ -127,9 +128,13 @@ void gemv(
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t = 4;
while (K % (n_per_t * WARP_SIZE) != 0) {
n_per_t >>= 1;
int n_per_t;
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
n_per_t = 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {

View File

@@ -2,30 +2,17 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,140 +227,100 @@ __global__ void affine_dequantize(
}
} // namespace cu
namespace {
inline array ensure_row_contiguous(
const array& x,
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
// Calculate the number of elements per thread
int per_thread = group_size_ / WARP_SIZE;
size_t size = w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread;
}
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
enc.set_output_array(biases);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
cu::affine_dequantize<DataType, group_size.value, bits.value>,
num_blocks,
block_dims,
w.data<uint8_t>(),
inputs[1].data<DataType>(),
inputs[2].data<DataType>(),
out.data<DataType>(),
out.size());
} else {
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
cu::affine_quantize<DataType, group_size.value, bits.value>,
num_blocks,
block_dims,
w.data<DataType>(),
out.data<uint8_t>(),
outputs[1].data<DataType>(),
outputs[2].data<DataType>(),
w.size());
}
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.size());
});
});
});
}
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
});
});
});

View File

@@ -0,0 +1,72 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
}
} // namespace
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core

View File

@@ -399,41 +399,7 @@ class Module(dict):
Returns:
The module instance after updating the submodules.
"""
def apply(dst, modules):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if self.is_module(current_value) and self.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
apply(self, modules)
_update_modules(self, modules, strict)
return self
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
@@ -639,6 +605,36 @@ class Module(dict):
self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
def _update_modules(dst, modules, strict):
if isinstance(modules, dict):
for k in modules:
if k in dst:
current_value = dst[k]
new_value = modules[k]
if Module.is_module(current_value) and Module.is_module(new_value):
dst[k] = new_value
elif isinstance(current_value, (dict, list)):
_update_modules(current_value, new_value, strict)
elif strict and new_value != {}:
raise ValueError(
f"Received invalid type: {type(new_value).__name__}."
)
elif strict:
raise ValueError(f'Module does not have sub-module named "{k}".')
elif isinstance(modules, list):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if Module.is_module(current_value) and Module.is_module(new_value):
dst[i] = new_value
elif isinstance(current_value, (dict, list)):
_update_modules(current_value, new_value, strict)
elif strict and new_value != {}:
raise ValueError(f"Received invalid type: {type(new_value).__name__}.")
elif strict:
raise ValueError(f"Received invalid type: {type(modules).__name__}.")
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)

View File

@@ -47,7 +47,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
def test_matmul_unaligned(self):
if not mx.metal.is_available():
if not mx.is_available(mx.gpu):
return
for dtype in self.dtypes:
@@ -61,8 +61,15 @@ class TestBlas(mlx_tests.MLXTestCase):
shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype)
def test_matvec_unaligned(self):
a = mx.random.normal(shape=(4, 128))
b = mx.random.normal(shape=(129,))[1:]
out = a @ b
np_out = np.array(a) @ np.array(b)
self.assertTrue(np.allclose(out, np_out))
def test_matmul_shapes(self):
if not mx.metal.is_available():
if not mx.is_available(mx.gpu):
return
shapes = [
@@ -1274,7 +1281,7 @@ class TestBlas(mlx_tests.MLXTestCase):
def test_gemv_gemm_same_precision(self):
mx.random.seed(0)
N = 256
if mx.metal.is_available():
if mx.is_available(mx.gpu):
t = mx.bfloat16
a = mx.random.normal([1, N]).astype(t)
b = mx.concatenate([a, a], axis=0).astype(t)

View File

@@ -279,6 +279,23 @@ class TestBase(mlx_tests.MLXTestCase):
del m.weight
self.assertFalse(hasattr(m, "weight"))
def test_circular_leaks(self):
y = mx.random.uniform(1)
mx.eval(y)
def make_and_update():
model = nn.Linear(1024, 512)
mx.eval(model.parameters())
leaves = {}
model.update_modules(leaves)
mx.synchronize()
pre = mx.get_active_memory()
make_and_update()
mx.synchronize()
post = mx.get_active_memory()
self.assertEqual(pre, post)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):