Compare commits


16 Commits

Author SHA1 Message Date
Angelos Katharopoulos
8269c9d02d Support unaligned M 2025-07-23 00:40:27 -07:00
Angelos Katharopoulos
903b40627c Add dynamic shared memory and improve qmm 2025-07-22 23:36:53 -07:00
Angelos Katharopoulos
700f7dcf01 Refactor the matmul a bit 2025-07-21 23:38:21 -07:00
Angelos Katharopoulos
6c60bd1cbf Fixed mma and working dequant 2025-07-21 04:47:42 -07:00
Angelos Katharopoulos
a64cc02a0c Somewhat working matmul primitives 2025-07-21 04:47:42 -07:00
Angelos Katharopoulos
346ae5fdb5 Refactor quantized 2025-07-21 04:47:41 -07:00
Awni Hannun
93d70419e7 [CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda

* comment
2025-07-18 21:47:31 -07:00
Awni Hannun
63f663d9c6 fix cuda manylinux version to match others (#2388) 2025-07-18 21:02:16 -07:00
Awni Hannun
84b4d96efa fix release build + patch bump (#2387) 2025-07-18 14:47:37 -07:00
Awni Hannun
aec67f2fa6 patch bump (#2386) 2025-07-18 12:25:48 -07:00
Gökdeniz Gülmez
deee214a95 Adding support for the Muon Optimizer (#1914)
* initial commit with working optimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove comments

* replace with  mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-18 12:25:28 -07:00
Cheng
45adec102c Add contiguous_copy_gpu util for copying array (#2379) 2025-07-18 06:44:25 -07:00
Cheng
31fc530c76 [CUDA] Add more ways finding CCCL headers in JIT (#2382) 2025-07-17 15:25:34 -07:00
Awni Hannun
fbb3f65a1a fix resource leaks in matmul and graph (#2383) 2025-07-17 06:50:15 -07:00
Angelos Katharopoulos
6b1b8ea91b [CUDA] Add work per thread to compile (#2368) 2025-07-17 06:47:52 -07:00
Awni Hannun
b2273733ea Test with CUDA 12.2 (#2375)
* Test with CUDA 12.0

* try older image

* fix cpu sort
2025-07-16 13:00:37 -07:00
60 changed files with 1540 additions and 379 deletions

View File

@@ -201,7 +201,7 @@ jobs:
cuda_build_and_test:
machine:
image: linux-cuda-12:default
image: linux-cuda-12:2023.11.1
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
@@ -210,7 +210,7 @@ jobs:
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
python3 -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
@@ -272,6 +272,7 @@ jobs:
name: Build Python package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
@@ -333,6 +334,7 @@ jobs:
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
@@ -364,7 +366,7 @@ jobs:
type: string
default: ""
machine:
image: linux-cuda-12:default
image: linux-cuda-12:2024.11.1
resource_class: gpu.nvidia.small.gen2
steps:
- checkout

View File

@@ -22,7 +22,7 @@ project(
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)

View File

@@ -19,3 +19,4 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -334,7 +334,9 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
@@ -426,7 +428,9 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());

View File

@@ -42,7 +42,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -90,7 +92,7 @@ target_compile_options(
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
"80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -130,3 +132,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# Make Thunderkittens available
FetchContent_Declare(
kittens
GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
GIT_SHALLOW TRUE)
FetchContent_MakeAvailable(kittens)
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")

View File

@@ -17,6 +17,52 @@ namespace cu {
constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
end_ = reinterpret_cast<void*>(
reinterpret_cast<char*>(buffer_) + small_pool_size);
next_free_ = reinterpret_cast<Block*>(buffer_);
auto num_blocks = small_pool_size / small_block_size;
auto curr = next_free_;
for (size_t i = 0; i < num_blocks - 1; ++i) {
curr->next = reinterpret_cast<Block*>(
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(buffer_));
}
void* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
next_free_ = next_free_->next;
return static_cast<void*>(b);
}
void SmallSizePool::free(void* p) {
auto b = static_cast<Block*>(p);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(void* p) {
return (p >= buffer_) && (p < end_);
}
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
@@ -36,7 +82,9 @@ Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size < page_size) {
if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
@@ -53,11 +101,19 @@ Buffer CudaAllocator::malloc(size_t size) {
lock.unlock();
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
// Try the scalar pool first
if (size <= small_block_size) {
buf->data = scalar_pool_.malloc();
}
if (!buf->data) {
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
}
lock.lock();
}
active_memory_ += size;
@@ -116,7 +172,11 @@ void CudaAllocator::cuda_free(void* buf) {
return;
}
}
cudaFree(buf);
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf);
}
}
size_t CudaAllocator::get_active_memory() const {
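For reference, the size-class logic in the hunk above maps a requested byte count to one of three buckets: requests of at most small_block_size bytes go to the managed SmallSizePool, sub-page requests are rounded up to the next power of two for the buffer cache, and anything larger is rounded up to whole pages. A minimal host-side sketch of that rounding, using the constants from this file (round_allocation and this next_power_of_2 are illustrative helpers, not MLX APIs):

#include <cstddef>
#include <cstdio>

// Constants mirrored from allocator.cpp above.
constexpr size_t page_size = 16384;
constexpr size_t small_block_size = 8;

// Round up to the next power of two (used for sub-page allocations).
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Illustrative version of the bucketing in CudaAllocator::malloc.
size_t round_allocation(size_t size) {
  if (size <= small_block_size) {
    return small_block_size; // served by SmallSizePool (one 8-byte block)
  } else if (size < page_size) {
    return next_power_of_2(size); // power-of-two class for the buffer cache
  } else {
    return page_size * ((size + page_size - 1) / page_size); // whole pages
  }
}

int main() {
  // 4 -> 8 (small pool), 100 -> 128, 20000 -> 32768 (two pages)
  printf("%zu %zu %zu\n",
         round_allocation(4),
         round_allocation(100),
         round_allocation(20000));
}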

View File

@@ -22,6 +22,28 @@ struct CudaBuffer {
size_t size;
};
class SmallSizePool {
private:
struct Block {
Block* next;
};
void* buffer_{nullptr};
Block* next_free_{nullptr};
void* end_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
void* malloc();
void free(void* p);
bool in_pool(void* p);
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
@@ -60,6 +82,7 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();

View File

@@ -166,6 +166,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
out.size(),

View File

@@ -219,6 +219,7 @@ void binary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
@@ -235,6 +236,7 @@ void binary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
@@ -269,6 +271,7 @@ void binary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),

View File

@@ -239,6 +239,7 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
@@ -256,6 +257,7 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
@@ -291,6 +293,7 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),

View File

@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t>\n";
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
} else {
os += "template <int NDIM, typename IdxT = uint32_t>\n";
os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
}
os += ") {\n";
// Index.
// Index. For non contiguous kernels we create a separate index
// variable per variable otherwise everyone uses `index`.
os +=
" IdxT index = cg::this_grid().thread_rank();\n"
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
" if (index >= size) {\n"
" return;\n"
" }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
} else if (contiguous) {
value = fmt::format("{}[index]", xname);
} else {
std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
value = fmt::format("{}[{}_idx]", xname, xname);
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write tape.
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
}
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
os += "}\n";
}
};
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names = {
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
kernel_names.push_back(
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size());
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) {
kernel_name += fmt::format("_contiguous<{}>", index_type);
kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
} else {
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -224,8 +293,9 @@ void Compiled::eval_gpu(
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
} // namespace mlx::core
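To make the work_per_thread change concrete, here is a hedged sketch of roughly what FusedKernelBuilder emits for a contiguous elementwise add (the names add_contiguous, a, b and out are illustrative; the real generated source is assembled from the format strings above, carries the JIT preamble, and may differ in details):

#include <cooperative_groups.h>
#include <cstdint>

namespace cg = cooperative_groups;

namespace mlx::core::cu {

template <typename IdxT = uint32_t, int work_per_thread = 1>
__global__ void add_contiguous(
    const float* a, const float* b, float* out, IdxT size) {
  // One linear index per thread, scaled by work_per_thread.
  IdxT index = cg::this_grid().thread_rank() * work_per_thread;
  if (index >= size) {
    return;
  }
  // Each thread handles up to work_per_thread consecutive elements.
  for (int i = 0; i < work_per_thread && index < size; i++) {
    float tmp_a = a[index];
    float tmp_b = b[index];
    float tmp_out = tmp_a + tmp_b;
    out[index] = tmp_out;
    index++;
  }
}

} // namespace mlx::core::cu

// Example instantiation, mirroring the kernel names registered above:
// mlx::core::cu::add_contiguous<uint32_t, 4><<<num_blocks, block_dims>>>(a, b, out, size);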

View File

@@ -82,6 +82,7 @@ void copy_contiguous(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());

View File

@@ -79,6 +79,7 @@ void copy_general(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
data_size,
@@ -94,6 +95,7 @@ void copy_general(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
data_size,

View File

@@ -82,6 +82,7 @@ void copy_general_dynamic(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
@@ -99,6 +100,7 @@ void copy_general_dynamic(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),

View File

@@ -71,6 +71,7 @@ void copy_general_input(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
@@ -85,6 +86,7 @@ void copy_general_input(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),

View File

@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
@@ -216,12 +215,14 @@ void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
cudaGraphNode_t node;
CHECK_CUDA_ERROR(
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
@@ -232,6 +233,7 @@ void CommandEncoder::add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
@@ -242,6 +244,7 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDimY = block_dim.y;
kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
CUgraphNode node;
CHECK_CUDA_ERROR(
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));

View File

@@ -45,25 +45,34 @@ class CommandEncoder {
void set_output_array(const array& arr);
template <typename F, typename... Params>
void
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
void add_kernel_node(
F* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
Params&&... params) {
constexpr size_t num = sizeof...(Params);
void* ptrs[num];
size_t i = 0;
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
std::forward<Params>(params)),
...);
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
}
void add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
void
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
void add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
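The new smem_bytes argument is what lets kernels that declare extern __shared__ storage (such as the qmm kernel later in this compare) be recorded into the CUDA graph with the right amount of dynamic shared memory. A hedged sketch of how a caller might use the variadic overload above; the kernel, sizes, and the encoder call site are illustrative, not taken from MLX:

// Illustrative kernel that stages data through dynamically sized shared memory.
__global__ void scale_via_smem(const float* in, float* out, int n) {
  extern __shared__ char shmem[]; // size supplied at launch / graph-record time
  float* tile = reinterpret_cast<float*>(shmem);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  tile[threadIdx.x] = (i < n) ? in[i] : 0.0f;
  __syncthreads();
  if (i < n) {
    out[i] = tile[threadIdx.x] * 2.0f;
  }
}

// Hypothetical call site: record the kernel with block.x * sizeof(float) bytes
// of dynamic shared memory.
//   uint32_t smem_bytes = block.x * sizeof(float);
//   encoder.add_kernel_node(scale_via_smem, grid, block, smem_bytes, in_ptr, out_ptr, n);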

View File

@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
} // namespace mlx::core

View File

@@ -52,13 +52,29 @@ const std::string& cuda_home() {
}
// Return the location of CCCL headers shipped with the distribution.
bool get_cccl_include(std::string* out) {
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
if (!std::filesystem::exists(cccl_headers)) {
return false;
}
*out = fmt::format("--include-path={}", cccl_headers.string());
return true;
const std::string& cccl_dir() {
static std::string dir = []() {
std::filesystem::path path;
#if defined(MLX_CCCL_DIR)
// First search the install dir if defined.
path = MLX_CCCL_DIR;
if (std::filesystem::exists(path)) {
return path.string();
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
}
// Get the cache directory for storing compiled results.
@@ -121,7 +137,8 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
return;
}
@@ -134,6 +151,9 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
// Return if |device|'s version is not newer than |major|.|minor| version.
@@ -234,8 +254,9 @@ JitModule::JitModule(
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include;
if (get_cccl_include(&cccl_include)) {
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
@@ -272,7 +293,8 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
// Load module.

View File

@@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -267,6 +266,7 @@ void LayerNorm::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
@@ -295,9 +295,7 @@ void LayerNormVJP::eval_gpu(
return x;
}
copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
return contiguous_copy_gpu(x, s);
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
@@ -381,6 +379,7 @@ void LayerNormVJP::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),

View File

@@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
@@ -152,6 +151,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);

View File

@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul {
public:
MatMul(
@@ -43,7 +72,7 @@ class MatMul {
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()) {
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
@@ -77,20 +106,6 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
@@ -104,7 +119,6 @@ class MatMul {
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
@@ -126,15 +140,15 @@ class MatMul {
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void run(
@@ -259,9 +273,9 @@ class MatMul {
return desc;
}
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
@@ -282,8 +296,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
array arr_copy = contiguous_copy_gpu(arr, s);
enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -389,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
auto c = inputs[2];
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -404,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -442,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),
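For context, the preference object cached above only takes effect when the cuBLASLt heuristic is queried: it caps the workspace the selected algorithm may use (4 MiB pre-Hopper, 32 MiB on Hopper+ per the comment). A hedged sketch of that query; pick_algo is a hypothetical helper and the exact descriptor ordering MLX uses is not shown in this hunk:

#include <cublasLt.h>

// Ask cuBLASLt for one algorithm that fits the workspace limit set on `pref`.
cublasLtMatmulHeuristicResult_t pick_algo(
    cublasLtHandle_t handle,
    cublasLtMatmulDesc_t matmul_desc,
    cublasLtMatrixLayout_t a_desc,
    cublasLtMatrixLayout_t b_desc,
    cublasLtMatrixLayout_t c_desc,
    cublasLtMatrixLayout_t out_desc,
    cublasLtMatmulPreference_t pref) {
  cublasLtMatmulHeuristicResult_t result{};
  int found = 0;
  cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
      handle, matmul_desc, a_desc, b_desc, c_desc, out_desc, pref,
      /* requestedAlgoCount */ 1, &result, &found);
  // A real caller checks status == CUBLAS_STATUS_SUCCESS and found > 0
  // (MLX wraps such checks in CHECK_CUBLAS_ERROR). result.algo is then passed
  // to cublasLtMatmul and result.workspaceSize stays within the preference cap.
  (void)status;
  return result;
}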

View File

@@ -0,0 +1,108 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/matmul/tiles.cuh"
namespace mlx::core::cu {
template <typename U, typename T>
__device__ inline void
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
/**
* Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
* float tile.
*
* We actually perform C += A @ B.T
*/
__device__ inline void mma_t(
Tile16x16<float>& C,
Tile16x16<__nv_bfloat16>& A,
Tile16x16<__nv_bfloat16>& B) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[0].x),
"+f"(C.values[0].y),
"+f"(C.values[1].x),
"+f"(C.values[1].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[0])),
"r"(*(uint32_t*)(&B.values[2])),
// C matrix
"f"(C.values[0].x),
"f"(C.values[0].y),
"f"(C.values[1].x),
"f"(C.values[1].y));
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[2].x),
"+f"(C.values[2].y),
"+f"(C.values[3].x),
"+f"(C.values[3].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[1])),
"r"(*(uint32_t*)(&B.values[3])),
// C matrix
"f"(C.values[2].x),
"f"(C.values[2].y),
"f"(C.values[3].x),
"f"(C.values[3].y));
}
/**
* Multiply larger register tiles by delegating to mma_t.
*/
template <typename U, typename T, int M, int N, int K>
__device__ inline void mma_t(
RegisterTile<U, M, N>& C,
RegisterTile<T, M, K>& A,
RegisterTile<T, N, K>& B) {
constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
MLX_UNROLL
for (int k = 0; k < TILES_K; k++) {
MLX_UNROLL
for (int m = 0; m < TILES_M; m++) {
MLX_UNROLL
for (int n = 0; n < TILES_N; n++) {
mma_t(
C.data[m * TILES_N + n],
A.data[m * TILES_K + k],
B.data[n * TILES_K + k]);
}
}
}
}
} // namespace mlx::core::cu
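As a plain reference for what the two PTX fragments above compute together, here is a hedged host-side sketch of the tile-level semantics, C += A * B^T on full 16x16 tiles stored row-major with float accumulation (the device code holds A and B in bfloat16 and distributes the tiles across a warp as described in tiles.cuh):

// Reference (non-performant) semantics of mma_t on whole 16x16 tiles:
// C[i][j] += sum_k A[i][k] * B[j][k], i.e. C += A @ B.T.
void mma_t_reference(
    float C[16][16],
    const float A[16][16],
    const float B[16][16]) {
  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 16; ++j) {
      float acc = C[i][j];
      for (int k = 0; k < 16; ++k) {
        acc += A[i][k] * B[j][k]; // B indexed as [j][k]: B is the transposed operand
      }
      C[i][j] = acc;
    }
  }
}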

View File

@@ -0,0 +1,419 @@
// Copyright © 2025 Apple Inc.
#pragma once
#define MLX_UNROLL _Pragma("unroll")
namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/**
* The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp.
*
* Each thread holds 8 values. They are distributed according to
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
*
* For use instructions see the individual methods eg load().
*/
template <typename T>
struct Tile16x16 {
using T2 = Vector2_t<T>;
T2 values[4];
__device__ inline void fill(T v) {
T2 v2 = {v, v};
for (int i = 0; i < 4; i++) {
values[i] = v2;
}
}
/**
* Load a 16x16 tile from shared memory.
*
* The instruction is a bit weird in the sense that the address provided by
* each thread and the elements loaded are not the same.
*
* We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
* result the warp provides 4*8 = 32 addresses one per row.
*
* Threads 0-7 provide the addresses for the first tile, 8-15 for the second
* and so on. For instance to load a non swizzled tile we would do
*
* base_addr + (laneid % 16) * BK + (laneid / 2) * 8
*
* See
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
*/
__device__ inline void load(uint32_t row_address) {
if constexpr (
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(*(uint32_t*)&(values[0])),
"=r"(*(uint32_t*)&(values[1])),
"=r"(*(uint32_t*)&(values[2])),
"=r"(*(uint32_t*)&(values[3]))
: "r"(row_address));
}
}
/**
* Store the tile to the address pointed to by `x`.
*
* The provided pointer is a generic pointer but this is meant to be used to
* store to global memory. For storing to shared memory we should use
* `stmatrix`.
*
* This also showcases the format of the tile quite nicely. Each register is
* holding two adjacent values. The indices are
*
* row + 0, col + 0
* row + 8, col + 0
* row + 0, col + 8
* row + 8, col + 8
*
* Given that we are dealing with Vector2_t<U> the column offsets are 4
* instead of 8.
*/
template <typename U>
__device__ inline void store_global(U* x, int N) {
using U2 = Vector2_t<U>;
U2* x2 = reinterpret_cast<U2*>(x);
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if constexpr (std::is_same_v<U2, T2>) {
x2[(row + 0) * (N / 2) + col + 0] = values[0];
x2[(row + 0) * (N / 2) + col + 4] = values[2];
x2[(row + 8) * (N / 2) + col + 0] = values[1];
x2[(row + 8) * (N / 2) + col + 4] = values[3];
} else if constexpr (
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
x2[(row + 0) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[0].x, values[0].y);
x2[(row + 0) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[2].x, values[2].y);
x2[(row + 8) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[1].x, values[1].y);
x2[(row + 8) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[3].x, values[3].y);
}
}
template <typename U>
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if (row < max_rows) {
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
}
if (row + 8 < max_rows) {
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
}
}
};
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
template <typename U>
__device__ inline void
store_global_safe(U* x, int N, int row, int col, int max_rows) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global_safe(
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
static constexpr int NUMEL = ROWS * COLS;
// Swizzle taken from ThunderKittens.
//
// See includes/types/shared/st.cuh
//
// I do feel that it is too math heavy and can be improved. Also the math is
// done every time although the addresses don't change from load to load. I
// guess we are expecting the compiler to figure that out.
static constexpr int swizzle_bytes =
(sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
: (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
T data[ROWS * COLS];
// Return a pointer to the element at (row, col) using the swizzle.
__device__ static inline T* ptr(T* ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint64_t addr =
(uint64_t)(&ptr
[outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols]);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (T*)(addr ^ swizzle);
} else {
return ptr + row * COLS + col;
}
}
// Return the location of the element at (row, col) using the swizzle.
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint32_t addr = ptr +
sizeof(T) *
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (addr ^ swizzle);
} else {
return ptr + sizeof(T) * (row * COLS + col);
}
}
// Convenience functions to edit elements going through the swizzle.
__device__ inline T& operator()(int row, int col) {
return *ptr(data, row, col);
}
__device__ inline void store(float4& v, int row, int col) {
*(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
}
__device__ inline void store(float2& v, int row, int col) {
*(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
}
__device__ inline void store(float& v, int row, int col) {
*(reinterpret_cast<float*>(ptr(data, row, col))) = v;
}
template <int N>
__device__ inline void store(T (&v)[N], int row, int col) {
if constexpr (sizeof(T) * N == 4) {
store(*(reinterpret_cast<float*>(&v[0])), row, col);
} else if constexpr (sizeof(T) * N == 8) {
store(*(reinterpret_cast<float2*>(&v[0])), row, col);
} else if constexpr (sizeof(T) * N == 16) {
store(*(reinterpret_cast<float4*>(&v[0])), row, col);
} else {
MLX_UNROLL
for (int i = 0; i < N; i++) {
*ptr(data, row, col + i) = v[i];
}
}
}
};
/**
* Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately.
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load(Tile& tile, const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
float4 tmp;
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
/**
* Copy 16 bytes from the global memory address pointed to by x to the smem
* address pointed to by row_address.
*
* A simple wrapper over the PTX.
*/
template <typename T>
__device__ inline void cp_async_16(uint32_t row_address, const T* x) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x)));
}
/**
* Submit all the previous async copies to be executed.
*/
__device__ inline void cp_async_commit() {
asm volatile("cp.async.commit_group;\n" ::);
}
/**
* Wait for all the async copies to finish.
*/
__device__ inline void cp_async_wait_all() {
asm volatile("cp.async.wait_all;\n" ::);
}
/**
* The asynchronous equivalent of load.
*
* Loads the tile from global memory by submitting a bunch of async copy
* instructions. The copy won't start until commit is called and we don't have
* a guarantee it will finish until wait is called.
*
* It should be used as follows
*
* load(...)
* load(...)
* cp_async_commit()
* do_other_stuff()
* cp_async_wait_all()
* do_stuff_with_shmem()
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
}
}
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
Tile& tile,
uint32_t base_address,
const T* x,
int N,
int max_rows) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
if (row + i * STEP_ROWS < max_rows) {
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
} else {
float4 tmp = {0, 0, 0, 0};
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
}
} // namespace mlx::core::cu
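To make the Tile16x16 fragment layout concrete, here is a hedged host-side sketch that prints, for each lane of the warp, the (row, column) coordinates of the eight tile elements it owns. The mapping is read directly off store_global_safe above: row = lane / 4, the paired columns start at 2 * (lane % 4) and 2 * (lane % 4) + 8, and values[1]/values[3] add a +8 row offset.

#include <cstdio>

// Element ownership for one 16x16 tile, per lane:
//   values[0] -> (row,     2c), (row,     2c + 1)
//   values[1] -> (row + 8, 2c), (row + 8, 2c + 1)
//   values[2] -> (row,     2c + 8), (row,     2c + 9)
//   values[3] -> (row + 8, 2c + 8), (row + 8, 2c + 9)
// with row = lane / 4 and c = lane % 4.
int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int row = lane / 4;
    int c = lane % 4;
    printf(
        "lane %2d: v0=(%2d,%2d..%2d) v1=(%2d,%2d..%2d) v2=(%2d,%2d..%2d) v3=(%2d,%2d..%2d)\n",
        lane,
        row, 2 * c, 2 * c + 1,
        row + 8, 2 * c, 2 * c + 1,
        row, 2 * c + 8, 2 * c + 9,
        row + 8, 2 * c + 8, 2 * c + 9);
  }
  return 0;
}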

View File

@@ -81,7 +81,6 @@ NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)

View File

@@ -2,30 +2,17 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -36,7 +23,8 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim = cg::this_grid().dim_threads();
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr float eps = 1e-7;
constexpr int simd_size = WARP_SIZE;
constexpr float n_bins = (1 << bits) - 1;
@@ -48,7 +36,7 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
size_t offset = tidx + grid_dim.x * size_t(tidy);
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t in_index = offset * values_per_reduce;
if (in_index >= size) {
return;
@@ -153,12 +141,13 @@ __global__ void affine_dequantize(
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim = cg::this_grid().dim_threads();
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
size_t offset = tidx + grid_dim.x * size_t(tidy);
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
@@ -238,143 +227,102 @@ __global__ void affine_dequantize(
}
} // namespace cu
namespace {
inline array ensure_row_contiguous(
const array& x,
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
// Calculate the number of elements per thread
int per_thread = group_size_ / WARP_SIZE;
size_t size = w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread;
}
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
enc.set_output_array(biases);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
auto kernel = cu::affine_dequantize<DataType, group_size(), bits()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<uint8_t>(),
inputs[1].data<DataType>(),
inputs[2].data<DataType>(),
out.data<DataType>(),
out.size());
} else {
auto kernel = cu::affine_quantize<DataType, group_size(), bits()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<DataType>(),
out.data<uint8_t>(),
outputs[1].data<DataType>(),
outputs[2].data<DataType>(),
w.size());
}
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.size());
});
});
});
}
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
});
});
});
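For readers new to the scheme, the affine quantization these kernels implement keeps, per group of group_size weights, a scale and a bias so that each element is reconstructed as q * scale + bias (exactly what affine_dequantize and load_quantized in qmm.cu compute). Below is a simplified, hedged host-side sketch of one group; the real kernel's scale/bias selection has additional edge cases (warp reductions, the eps guard, per-bit-width packing) that are not reproduced here:

#include <algorithm>
#include <cmath>

// Simplified per-group affine quantization (an illustrative approximation of
// the kernel above, not its exact behavior).
void quantize_group(
    const float* w, int group_size, int bits,
    unsigned* q, float* scale, float* bias) {
  float w_min = w[0], w_max = w[0];
  for (int i = 1; i < group_size; ++i) {
    w_min = std::min(w_min, w[i]);
    w_max = std::max(w_max, w[i]);
  }
  float n_bins = float((1 << bits) - 1);
  *scale = std::max((w_max - w_min) / n_bins, 1e-7f); // the kernel defines eps = 1e-7
  *bias = w_min;
  for (int i = 0; i < group_size; ++i) {
    float z = std::round((w[i] - *bias) / *scale);
    q[i] = (unsigned)std::min(std::max(z, 0.0f), n_bins);
  }
}

// Dequantization matches the kernels exactly: w_hat = q * scale + bias.
float dequantize_one(unsigned q, float scale, float bias) {
  return float(q) * scale + bias;
}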

View File

@@ -0,0 +1,228 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/matmul/mma.cuh"
#include "mlx/backend/cuda/matmul/tiles.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
namespace mlx::core {
namespace cu {
template <int NUM_WARPS, int group_size, int bits, typename T, typename Tile>
__device__ inline void load_quantized(
Tile& tile,
const uint8_t* x,
const T* scales,
const T* biases,
int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(uint32_t) * get_pack_factor<bits>();
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
constexpr int MASK = (1 << bits) - 1;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
const int Nx = N / get_pack_factor<bits>();
const int Ng = N / group_size;
x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor<bits>());
scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
T vs[ELEMENTS_PER_LOAD];
uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
T s = scales[i * STEP_ROWS * Ng];
T b = biases[i * STEP_ROWS * Ng];
MLX_UNROLL
for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
}
tile.store(vs, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
template <
typename T,
int BM,
int BN,
int BK,
int group_size,
int bits,
bool aligned_M>
__global__ void qmm_t(
const T* x,
const uint8_t* w,
const T* scales,
const T* biases,
T* y,
int M,
int N,
int K) {
constexpr int WARPS_M = 2;
constexpr int WARPS_N = 4;
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M;
constexpr int WARP_STEP_N = BN / WARPS_N;
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N;
const int wn = warpid % WARPS_N;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&xs)[1] = *(SharedTile<T, BM, BK>(*)[1])(&shmem[0]);
SharedTile<T, BN, BK>(&ws)[1] =
*(SharedTile<T, BN, BK>(*)[1])(&shmem[1 * sizeof(T) * BM * BK]);
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
const int max_rows = M - blockIdx.y * BM;
x += blockIdx.y * BM * K;
w += blockIdx.x * BN * K / get_pack_factor<bits>();
scales += blockIdx.x * BN * K / group_size;
biases += blockIdx.x * BN * K / group_size;
y += blockIdx.y * BM * N + blockIdx.x * BN;
C.fill(0);
int tic = 0;
uint32_t base_addr_xs[1], base_addr_ws[1];
base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
if (aligned_M || max_rows >= BM) {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global(y, N, offset_m, offset_n);
} else {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async_safe<NUM_WARPS>(
xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global_safe(y, N, offset_m, offset_n, max_rows);
}
}
} // namespace cu
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.dtype() != bfloat16) {
throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
}
if (!transpose_) {
throw std::invalid_argument(
"[qmm] Only transposed matmul is supported for now");
}
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
auto kernel =
cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
if (M % BM != 0) {
kernel = cu::
qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
}
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
enc.add_kernel_node(
kernel,
grid,
2 * 4 * 32,
1 * sizeof(DataType) * (BM * BK + BN * BK),
x.data<DataType>(),
w.data<uint8_t>(),
scales.data<DataType>(),
biases.data<DataType>(),
out.data<DataType>(),
M,
N,
K);
});
});
});
}
} // namespace mlx::core
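For context, the host function above is what the Python-level quantized matmul ultimately dispatches to for transposed bfloat16 inputs. A minimal usage sketch, assuming the standard mx.quantize / mx.quantized_matmul Python API (illustrative only, not part of this diff):

import mlx.core as mx

group_size, bits = 64, 4
x = mx.random.normal((8, 512)).astype(mx.bfloat16)     # activations, M x K
w = mx.random.normal((1024, 512)).astype(mx.bfloat16)  # weights, N x K
wq, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
# transpose=True multiplies by the dequantized weights transposed, i.e. the qmm_t path above.
y = mx.quantized_matmul(
    x, wq, scales, biases, transpose=True, group_size=group_size, bits=bits
)
print(y.shape)  # (8, 1024)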

View File

@@ -0,0 +1,113 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
}
} // namespace
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
out.set_data(allocator::malloc(out.nbytes()));
// Make sure the last two dims of x and w, s, b are contiguous. This should
// be relaxed for x.
array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
int K = x.shape(-1);
int M = non_batched ? x.size() / K : x.shape(-2);
int N = out.shape(-1);
qmm(x,
w,
scales,
biases,
out,
transpose_,
group_size_,
bits_,
M,
N,
K,
enc,
s);
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core
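The two branches above implement the fast affine quantize / dequantize ops. A small round-trip sketch from Python, assuming the usual mx.quantize / mx.dequantize signatures (illustrative only):

import mlx.core as mx

w = mx.random.normal((256, 256))
wq, scales, biases = mx.quantize(w, group_size=64, bits=4)        # quantize branch
w_hat = mx.dequantize(wq, scales, biases, group_size=64, bits=4)  # dequantize branch
print(mx.max(mx.abs(w - w_hat)))  # small affine quantization error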

View File

@@ -0,0 +1,42 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core
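As a quick check on the helpers above: the pack factor is how many quantized values are stored per packed unit, and bytes-per-pack is the byte width of that unit (3-, 5- and 6-bit values are packed in 3-, 5- and 3-byte groups respectively). A short sketch mirroring the constexpr logic for wsize == 8 (illustrative only):

def pack_factor(bits, wsize=8):
    return 8 if bits in (3, 5) else (4 if bits == 6 else wsize // bits)

def bytes_per_pack(bits, wsize=8):
    power_of_2 = (bits & (bits - 1)) == 0
    return wsize // 8 if power_of_2 else (5 if bits == 5 else 3)

for bits in (2, 3, 4, 5, 6, 8):  # the widths handled by dispatch_bits
    print(bits, pack_factor(bits), bytes_per_pack(bits))
# 4-bit: 2 values per byte; 3-bit: 8 values in 3 bytes; 5-bit: 8 values in 5 bytes.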

View File

@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbitsc,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbits,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,

View File

@@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
- array in_copy(in.shape(), in.dtype(), nullptr, {});
- copy_gpu(in, in_copy, CopyType::General, s);
+ array in_copy = contiguous_copy_gpu(in, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -120,6 +120,7 @@ void all_reduce(
kernel,
blocks,
threads,
0,
static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
@@ -146,6 +147,7 @@ void all_reduce(
kernel,
blocks,
threads,
0,
static_cast<T*>(indata),
out.data<U>(),
block_step,

View File

@@ -230,7 +230,7 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
- kernel, grid, blocks, indata, out.data<U>(), args);
+ kernel, grid, blocks, 0, indata, out.data<U>(), args);
});
});
});

View File

@@ -41,7 +41,8 @@ void init_reduce(
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
- encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
+ encoder.add_kernel_node(
+     kernel, grid, block, 0, out.data<U>(), out.size());
});
});
}

View File

@@ -269,7 +269,7 @@ void row_reduce_simple(
int size = plan.shape.back();
encoder.add_kernel_node(
- kernel, grid, block, indata, out.data<U>(), out.size(), size);
+ kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
});
});
}
@@ -322,7 +322,7 @@ void row_reduce_looped(
});
encoder.add_kernel_node(
- kernel, grid, block, indata, out.data<U>(), out.size(), args);
+ kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
});
});
}

View File

@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
- auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -233,6 +232,7 @@ void RMSNorm::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
@@ -259,9 +259,7 @@ void RMSNormVJP::eval_gpu(
return x;
}
copied = true;
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
- return x_copy;
+ return contiguous_copy_gpu(x, s);
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();
@@ -330,6 +328,7 @@ void RMSNormVJP::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),

View File

@@ -325,6 +325,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -341,6 +342,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -360,6 +362,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -381,6 +384,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),

View File

@@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
- array arr_copy(in.shape(), in.dtype(), nullptr, {});
- copy_gpu(in, arr_copy, CopyType::General, s);
- in = std::move(arr_copy);
+ in = contiguous_copy_gpu(in, s);
out.copy_shared_buffer(in);
}
@@ -416,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
in.data_size() / axis_size,
block_dim,
0,
in.data<T>(),
out.data<U>(),
axis_size);
@@ -445,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim,
0,
in.data<T>(),
out.data<U>(),
axis_size,

View File

@@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
- auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -153,6 +152,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);

View File

@@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
- in = array(trans.shape(), trans.dtype(), nullptr, {});
- copy_gpu(trans, in, CopyType::General, s);
+ in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);

View File

@@ -133,6 +133,7 @@ void ternary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
@@ -151,6 +152,7 @@ void ternary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
@@ -180,6 +182,7 @@ void ternary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),

View File

@@ -142,6 +142,7 @@ void unary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
@@ -154,6 +155,7 @@ void unary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size(),

View File

@@ -46,4 +46,10 @@ void copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
array contiguous_copy_gpu(const array& arr, const Stream& s) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -43,4 +43,7 @@ void copy_gpu_inplace(
// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
} // namespace mlx::core

View File

@@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize
- auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
- copy_gpu(wt_view, wt_transpose, CopyType::General, s);
+ array wt_transpose = contiguous_copy_gpu(wt_view, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
@@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
auto wt = inputs[1];
if (!in.flags().row_contiguous) {
- array arr_copy(in.shape(), in.dtype(), nullptr, {});
- copy_gpu(in, arr_copy, CopyType::General, s);
- copies.push_back(arr_copy);
- in = arr_copy;
+ in = contiguous_copy_gpu(in, s);
+ copies.push_back(in);
}
if (!wt.flags().row_contiguous) {
- array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
- copy_gpu(wt, arr_copy, CopyType::General, s);
- copies.push_back(arr_copy);
- wt = arr_copy;
+ wt = contiguous_copy_gpu(wt, s);
+ copies.push_back(wt);
}
// 3D conv

View File

@@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
- auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
- array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
- copy_gpu(arr, arr_copy, CopyType::General, s);
+ array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose(
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
}
}
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
@@ -1894,8 +1891,7 @@ void segmented_mm(
return std::make_tuple(false, x);
}
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy);
};

View File

@@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
- auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
@@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
- auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();

View File

@@ -20,8 +20,7 @@ namespace {
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
- array x_copy(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// input for the axes with stride smaller than the minimum reduction
// stride.
if (plan.type == GeneralReduce) {
- array in_copy(in.shape(), in.dtype(), nullptr, {});
- copy_gpu(in, in_copy, CopyType::General, s);
+ array in_copy = contiguous_copy_gpu(in, s);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
- array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
- copy_gpu(arr, arr_copy, CopyType::General, s);
+ array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {

View File

@@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
- array arr_copy(in.shape(), in.dtype(), nullptr, {});
- copy_gpu(in, arr_copy, CopyType::General, s);
- in = std::move(arr_copy);
+ in = contiguous_copy_gpu(in, s);
out.copy_shared_buffer(in);
}

View File

@@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
- auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
- copy_gpu(x, x_copy, CopyType::General, s);
+ array x_copy = contiguous_copy_gpu(x, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 26
- #define MLX_VERSION_PATCH 3
+ #define MLX_VERSION_PATCH 5
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -849,28 +849,28 @@ class Adafactor(Optimizer):
class Muon(Optimizer):
r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
r"""The Muon optimizer.
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
the advantage that it can be stably run in bfloat16 on the GPU.
For more details, see: https://kellerjordan.github.io/posts/muon/
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
original implementation: `Muon: An optimizer for hidden layers in neural
networks <https://kellerjordan.github.io/posts/muon/>`_
Note:
- This optimizer may not be optimal for the embedding layer, the final fully connected layer,
or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
- For 4D convolutional filters, it works by flattening their last dimensions.
- Muon may be sub-optimal for the embedding layer, the final fully
connected layer, or any 0D/1D parameters. Those should be optimized
by a different method (e.g., :class:`AdamW`).
- For 4D convolutional filters, it works by flattening their last
dimensions.
Args:
learning_rate (float or callable): The learning rate used by the internal SGD.
learning_rate (float or callable): The learning rate.
momentum (float, optional): The momentum strength. Default: ``0.95``
weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
Default: ``5``
weight_decay (float, optional): The weight decay (L2 penalty).
Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
better performance. Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
orthogonalization. Default: ``5``
"""
def __init__(
@@ -882,7 +882,7 @@ class Muon(Optimizer):
ns_steps: int = 5,
):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
@@ -893,80 +893,59 @@ class Muon(Optimizer):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
- def _zeropower_via_newtonschulz5(self, G, steps: int):
-     """
-     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-     zero even beyond the point where the iteration no longer converges all the way to one everywhere
-     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-     performance at all relative to UV^T, where USV^T = G is the SVD.
-     """
-     assert G.ndim >= 2
+ def _zeropower_via_newtonschulz5(self, X, steps: int):
+     assert (
+         X.ndim == 2
+     ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
a, b, c = (3.4445, -4.7750, 2.0315)
- X = G.astype(mx.bfloat16)
- transpose_needed = G.shape[-2] > G.shape[-1]
+ transpose_needed = X.shape[-2] > X.shape[-1]
if transpose_needed:
X = X.T
- # Ensure spectral norm is at most 1
- norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
- X = X / norm
- # Perform the NS iterations
+ X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
for _ in range(steps):
A = X @ X.T
- B = b * A + c * (A @ A)
- X = a * X + B @ X
+ B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
+ X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
if transpose_needed:
X = X.T
return X
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
- # Apply weight decay
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
- # Update momentum buffer
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
- # Get effective gradient
if self.nesterov:
- effective_grad = gradient * (1 - self.momentum) + v * self.momentum
+ update = gradient * (1 - self.momentum) + v * self.momentum
else:
- effective_grad = v
- # For tensors with fewer than 2 dimensions, skip Newton-Schulz
- if effective_grad.ndim < 2:
-     orthogonalized_grad = effective_grad
-     scale_factor = 1.0
- else:
-     # Save original shape for 4D conv filters
-     original_shape = effective_grad.shape
-     reshape_needed = effective_grad.ndim > 2
+ update = v
+ lr = self.learning_rate.astype(gradient.dtype)
+ if update.ndim >= 2:
+     original_shape = update.shape
+     reshape_needed = update.ndim > 2
if reshape_needed:
- effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
- # Apply Newton-Schulz orthogonalization
- orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
- # Reshape back if needed
+ update = mx.reshape(update, (update.shape[0], -1))
+ update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
if reshape_needed:
- orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
- # Calculate scaling factor
- # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
- scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
- return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+ update = mx.reshape(update, original_shape)
+ lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
+ return parameter - lr * update
def clip_grad_norm(grads, max_norm):
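A short usage sketch for the optimizer above, matching how the tests exercise it (the toy model and training step are illustrative assumptions, not part of this diff):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
muon = optim.Muon(learning_rate=0.02, momentum=0.95, nesterov=True, ns_steps=5)

def loss_fn(m, x, y):
    return nn.losses.cross_entropy(m(x), y, reduction="mean")

x = mx.random.normal((16, 32))
y = mx.random.randint(0, 10, (16,))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
# 2D weights get the Newton-Schulz orthogonalized update; 1D biases fall back to SGD-momentum.
muon.update(model, grads)
mx.eval(model.parameters(), muon.state)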

View File

@@ -1,7 +1,7 @@
#!/bin/bash
auditwheel repair dist/* \
- --plat manylinux_2_39_x86_64 \
+ --plat manylinux_2_35_x86_64 \
--exclude libcublas* \
--exclude libnvrtc* \
-w wheel_tmp

View File

@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c
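The new cases above lean on the addmm contract out = beta * c + alpha * (a @ b), with c allowed to be transposed or broadcastable against the product. A one-line reminder (illustrative):

import mlx.core as mx

a, b, c = mx.ones((5, 4)), mx.ones((4, 3)), mx.ones((1, 3))  # c broadcasts over rows
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
assert mx.allclose(out, 1.5 * c + 0.5 * (a @ b))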

View File

@@ -307,7 +307,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
# Test update
updated_params = optim.apply_gradients(grads, params)
# Check that shapes are preserved
self.assertTrue(
tree_equal(
@@ -316,7 +316,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
updated_params,
)
)
# Check that parameters actually changed
self.assertFalse(
tree_equal(
@@ -325,11 +325,11 @@ class TestOptimizers(mlx_tests.MLXTestCase):
updated_params,
)
)
# Test with different configurations
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
optim_no_nesterov.apply_gradients(grads, params)
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
optim_no_momentum.apply_gradients(grads, params)

View File

@@ -39,6 +39,14 @@ target_sources(
linalg_tests.cpp
${METAL_TEST_SOURCES})
if(MLX_BUILD_CUDA)
# Find the CCCL headers in install dir.
target_compile_definitions(
mlx
PRIVATE
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
endif()
target_link_libraries(tests PRIVATE mlx doctest)
doctest_discover_tests(tests)
add_test(NAME tests COMMAND tests)