Compare commits

..

14 Commits

Author SHA1 Message Date
Awni Hannun
7f39e9c299 nits 2025-07-17 06:26:43 -07:00
Gökdeniz Gülmez
baad6e392b Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-17 13:07:54 +02:00
Gökdeniz Gülmez
784e0716fe Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-16 21:58:17 +02:00
Goekdeniz-Guelmez
df6d9e972f nits and adding it to test 2025-07-16 19:13:40 +02:00
Gökdeniz Gülmez
650c956fe6 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-07-16 16:29:10 +02:00
Gökdeniz Gülmez
d3d575cce7 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-04-21 20:27:33 +02:00
Gökdeniz Gülmez
8f2744dcf3 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-21 08:50:43 +01:00
Gökdeniz Gülmez
b12be4b7e0 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-12 16:52:21 +01:00
Gökdeniz Gülmez
ebfcb4a14f Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-10 17:10:50 +01:00
Gökdeniz Gülmez
79175a1f35 Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-07 11:41:19 +01:00
Gökdeniz Gülmez
59d4e4f61d Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-05 23:09:44 +01:00
Gökdeniz Gülmez
44f776921c Merge branch 'ml-explore:main' into adding-Muon-optimizer 2025-03-05 10:05:10 +01:00
Goekdeniz-Guelmez
871ee2b9b0 update ACKNOWLEDGMENTS.md 2025-02-28 23:24:39 +01:00
Goekdeniz-Guelmez
6c048ab4da initial commit with workong optmimizer 2025-02-28 23:16:51 +01:00
57 changed files with 333 additions and 1483 deletions

View File

@@ -272,7 +272,6 @@ jobs:
name: Build Python package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
@@ -334,7 +333,6 @@ jobs:
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
@@ -366,7 +364,7 @@ jobs:
type: string
default: ""
machine:
image: linux-cuda-12:2024.11.1
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout

View File

@@ -22,7 +22,7 @@ project(
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)

View File

@@ -42,9 +42,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -92,7 +90,7 @@ target_compile_options(
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"80"
"70;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -132,12 +130,3 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# Make Thunderkittens available
FetchContent_Declare(
kittens
GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
GIT_SHALLOW TRUE)
FetchContent_MakeAvailable(kittens)
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")

View File

@@ -17,52 +17,6 @@ namespace cu {
constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
end_ = reinterpret_cast<void*>(
reinterpret_cast<char*>(buffer_) + small_pool_size);
next_free_ = reinterpret_cast<Block*>(buffer_);
auto num_blocks = small_pool_size / small_block_size;
auto curr = next_free_;
for (size_t i = 0; i < num_blocks - 1; ++i) {
curr->next = reinterpret_cast<Block*>(
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(buffer_));
}
void* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
next_free_ = next_free_->next;
return static_cast<void*>(b);
}
void SmallSizePool::free(void* p) {
auto b = static_cast<Block*>(p);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(void* p) {
return (p >= buffer_) && (p < end_);
}
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
@@ -82,9 +36,7 @@ Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
@@ -101,19 +53,11 @@ Buffer CudaAllocator::malloc(size_t size) {
lock.unlock();
buf = new CudaBuffer{nullptr, size};
// Try the scalar pool first
if (size <= small_block_size) {
buf->data = scalar_pool_.malloc();
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
if (!buf->data) {
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
}
lock.lock();
}
active_memory_ += size;
@@ -172,11 +116,7 @@ void CudaAllocator::cuda_free(void* buf) {
return;
}
}
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf);
}
cudaFree(buf);
}
size_t CudaAllocator::get_active_memory() const {

View File

@@ -22,28 +22,6 @@ struct CudaBuffer {
size_t size;
};
class SmallSizePool {
private:
struct Block {
Block* next;
};
void* buffer_{nullptr};
Block* next_free_{nullptr};
void* end_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
void* malloc();
void free(void* p);
bool in_pool(void* p);
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
@@ -82,7 +60,6 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();

View File

@@ -166,7 +166,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
out.size(),

View File

@@ -219,7 +219,6 @@ void binary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
@@ -236,7 +235,6 @@ void binary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
@@ -271,7 +269,6 @@ void binary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),

View File

@@ -239,7 +239,6 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
@@ -257,7 +256,6 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
@@ -293,7 +291,6 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),

View File

@@ -53,10 +53,9 @@ struct FusedKernelBuilder {
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
os += "template <typename IdxT = uint32_t>\n";
} else {
os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
os += "template <int NDIM, typename IdxT = uint32_t>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
@@ -68,46 +67,12 @@ struct FusedKernelBuilder {
}
os += ") {\n";
// Index. For non contiguous kernels we create a separate index
// variable per variable otherwise everyone uses `index`.
// Index.
os +=
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
" IdxT index = cg::this_grid().thread_rank();\n"
" if (index >= size) {\n"
" return;\n"
" }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -124,9 +89,12 @@ struct FusedKernelBuilder {
} else if (contiguous) {
value = fmt::format("{}[index]", xname);
} else {
value = fmt::format("{}[{}_idx]", xname, xname);
std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write tape.
@@ -145,30 +113,14 @@ struct FusedKernelBuilder {
}
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
os += "}\n";
}
};
@@ -204,28 +156,15 @@ void Compiled::eval_gpu(
builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
std::vector<std::string> kernel_names = {
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
kernel_names.push_back(
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
@@ -268,21 +207,13 @@ void Compiled::eval_gpu(
args.append<uint32_t>(outputs[0].data_size());
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) {
kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
kernel_name += fmt::format("_contiguous<{}>", index_type);
} else {
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -293,9 +224,8 @@ void Compiled::eval_gpu(
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
} // namespace mlx::core

View File

@@ -82,7 +82,6 @@ void copy_contiguous(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());

View File

@@ -79,7 +79,6 @@ void copy_general(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
data_size,
@@ -95,7 +94,6 @@ void copy_general(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
data_size,

View File

@@ -82,7 +82,6 @@ void copy_general_dynamic(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
@@ -100,7 +99,6 @@ void copy_general_dynamic(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),

View File

@@ -71,7 +71,6 @@ void copy_general_input(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
@@ -86,7 +85,6 @@ void copy_general_input(
kernel,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),

View File

@@ -66,6 +66,7 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
@@ -215,14 +216,12 @@ void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
cudaGraphNode_t node;
CHECK_CUDA_ERROR(
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
@@ -233,7 +232,6 @@ void CommandEncoder::add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
@@ -244,7 +242,6 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDimY = block_dim.y;
kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
CUgraphNode node;
CHECK_CUDA_ERROR(
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));

View File

@@ -45,34 +45,25 @@ class CommandEncoder {
void set_output_array(const array& arr);
template <typename F, typename... Params>
void add_kernel_node(
F* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
Params&&... params) {
void
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
constexpr size_t num = sizeof...(Params);
void* ptrs[num];
size_t i = 0;
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
std::forward<Params>(params)),
...);
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
}
void add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
void add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
void
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());

View File

@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
} // namespace mlx::core

View File

@@ -52,29 +52,13 @@ const std::string& cuda_home() {
}
// Return the location of CCCL headers shipped with the distribution.
const std::string& cccl_dir() {
static std::string dir = []() {
std::filesystem::path path;
#if defined(MLX_CCCL_DIR)
// First search the install dir if defined.
path = MLX_CCCL_DIR;
if (std::filesystem::exists(path)) {
return path.string();
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
bool get_cccl_include(std::string* out) {
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
if (!std::filesystem::exists(cccl_headers)) {
return false;
}
*out = fmt::format("--include-path={}", cccl_headers.string());
return true;
}
// Get the cache directory for storing compiled results.
@@ -137,8 +121,7 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return;
}
@@ -151,9 +134,6 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
// Return if |device|'s version is not newer than |major|.|minor| version.
@@ -254,9 +234,8 @@ JitModule::JitModule(
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
std::string cccl_include;
if (get_cccl_include(&cccl_include)) {
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
@@ -293,8 +272,7 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
}
// Load module.

View File

@@ -237,7 +237,8 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -266,7 +267,6 @@ void LayerNorm::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
@@ -295,7 +295,9 @@ void LayerNormVJP::eval_gpu(
return x;
}
copied = true;
return contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
@@ -379,7 +381,6 @@ void LayerNormVJP::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),

View File

@@ -108,7 +108,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
@@ -151,7 +152,6 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);

View File

@@ -27,35 +27,6 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul {
public:
MatMul(
@@ -72,7 +43,7 @@ class MatMul {
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
: handle_(device.lt_handle()) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
@@ -106,6 +77,20 @@ class MatMul {
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
@@ -119,6 +104,7 @@ class MatMul {
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
@@ -140,15 +126,15 @@ class MatMul {
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
}
~MatMul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
}
void run(
@@ -273,9 +259,9 @@ class MatMul {
return desc;
}
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
@@ -296,7 +282,8 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy = contiguous_copy_gpu(arr, s);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -402,7 +389,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto c = inputs[2];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -415,24 +404,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
}
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -470,6 +442,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),

View File

@@ -1,108 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/matmul/tiles.cuh"
namespace mlx::core::cu {
template <typename U, typename T>
__device__ inline void
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
/**
* Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
* float tile.
*
* We actually perform C += A @ B.T
*/
__device__ inline void mma_t(
Tile16x16<float>& C,
Tile16x16<__nv_bfloat16>& A,
Tile16x16<__nv_bfloat16>& B) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[0].x),
"+f"(C.values[0].y),
"+f"(C.values[1].x),
"+f"(C.values[1].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[0])),
"r"(*(uint32_t*)(&B.values[2])),
// C matrix
"f"(C.values[0].x),
"f"(C.values[0].y),
"f"(C.values[1].x),
"f"(C.values[1].y));
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};"
// D matrix
: "+f"(C.values[2].x),
"+f"(C.values[2].y),
"+f"(C.values[3].x),
"+f"(C.values[3].y)
// A matrix
: "r"(*(uint32_t*)(&A.values[0])),
"r"(*(uint32_t*)(&A.values[1])),
"r"(*(uint32_t*)(&A.values[2])),
"r"(*(uint32_t*)(&A.values[3])),
// B matrix
"r"(*(uint32_t*)(&B.values[1])),
"r"(*(uint32_t*)(&B.values[3])),
// C matrix
"f"(C.values[2].x),
"f"(C.values[2].y),
"f"(C.values[3].x),
"f"(C.values[3].y));
}
/**
* Multiply larger register tiles by delegating to mma_t.
*/
template <typename U, typename T, int M, int N, int K>
__device__ inline void mma_t(
RegisterTile<U, M, N>& C,
RegisterTile<T, M, K>& A,
RegisterTile<T, N, K>& B) {
constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
MLX_UNROLL
for (int k = 0; k < TILES_K; k++) {
MLX_UNROLL
for (int m = 0; m < TILES_M; m++) {
MLX_UNROLL
for (int n = 0; n < TILES_N; n++) {
mma_t(
C.data[m * TILES_N + n],
A.data[m * TILES_K + k],
B.data[n * TILES_K + k]);
}
}
}
}
} // namespace mlx::core::cu

View File

@@ -1,419 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#define MLX_UNROLL _Pragma("unroll")
namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/**
* The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp.
*
* Each thread holds 8 values. They are distributed according to
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
*
* For use instructions see the individual methods eg load().
*/
template <typename T>
struct Tile16x16 {
using T2 = Vector2_t<T>;
T2 values[4];
__device__ inline void fill(T v) {
T2 v2 = {v, v};
for (int i = 0; i < 4; i++) {
values[i] = v2;
}
}
/**
* Load a 16x16 tile from shared memory.
*
* The instruction is a bit weird in the sense that the address provided by
* each thread and the elements loaded are not the same.
*
* We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
* result the warp provides 4*8 = 32 addresses one per row.
*
* Threads 0-7 provide the addresses for the first tile, 8-15 for the second
* and so on. For instance to load a non swizzled tile we would do
*
* base_addr + (laneid % 16) * BK + (laneid / 2) * 8
*
* See
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
*/
__device__ inline void load(uint32_t row_address) {
if constexpr (
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(*(uint32_t*)&(values[0])),
"=r"(*(uint32_t*)&(values[1])),
"=r"(*(uint32_t*)&(values[2])),
"=r"(*(uint32_t*)&(values[3]))
: "r"(row_address));
}
}
/**
* Store the tile to the address pointed to by `x`.
*
* The provided pointer is a generic pointer but this is meant to be used to
* store to global memory. For storing to shared memory we should use
* `stmatrix`.
*
* This also showcases the format of the tile quite nicely. Each register is
* holding to adjacent values. The indices are
*
* row + 0, col + 0
* row + 8, col + 0
* row + 0, col + 8
* row + 8, col + 8
*
* Given that we are dealing with Vector2_t<U> the column offsets are 4
* instead of 8.
*/
template <typename U>
__device__ inline void store_global(U* x, int N) {
using U2 = Vector2_t<U>;
U2* x2 = reinterpret_cast<U2*>(x);
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if constexpr (std::is_same_v<U2, T2>) {
x2[(row + 0) * (N / 2) + col + 0] = values[0];
x2[(row + 0) * (N / 2) + col + 4] = values[2];
x2[(row + 8) * (N / 2) + col + 0] = values[1];
x2[(row + 8) * (N / 2) + col + 4] = values[3];
} else if constexpr (
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
x2[(row + 0) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[0].x, values[0].y);
x2[(row + 0) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[2].x, values[2].y);
x2[(row + 8) * (N / 2) + col + 0] =
__floats2bfloat162_rn(values[1].x, values[1].y);
x2[(row + 8) * (N / 2) + col + 4] =
__floats2bfloat162_rn(values[3].x, values[3].y);
}
}
template <typename U>
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
const int laneid = threadIdx.x % 32;
const int row = laneid / 4;
const int col = laneid % 4;
if (row < max_rows) {
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
}
if (row + 8 < max_rows) {
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
}
}
};
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
template <typename U>
__device__ inline void
store_global_safe(U* x, int N, int row, int col, int max_rows) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global_safe(
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
static constexpr int NUMEL = ROWS * COLS;
// Swizzle taken from ThunderKittens.
//
// See inludes/types/shared/st.cuh
//
// I do feel that it is too math heavy and can be improved. Also the math is
// done every time although the addresses don't change from load to load. I
// guess we are expecting the compiler to figure that out.
static constexpr int swizzle_bytes =
(sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
: (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
T data[ROWS * COLS];
// Return a pointer to the element at (row, col) using the swizzle.
__device__ static inline T* ptr(T* ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint64_t addr =
(uint64_t)(&ptr
[outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols]);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (T*)(addr ^ swizzle);
} else {
return ptr + row * COLS + col;
}
}
// Return the location of the element at (row, col) using the swizzle.
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
static constexpr int swizzle_repeat = swizzle_bytes * 8;
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
const int outer_idx = col / subtile_cols;
const uint32_t addr = ptr +
sizeof(T) *
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
col % subtile_cols);
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
return (addr ^ swizzle);
} else {
return ptr + sizeof(T) * (row * COLS + col);
}
}
// Convenience functions to edit elements going through the swizzle.
__device__ inline T& operator()(int row, int col) {
return *ptr(data, row, col);
}
__device__ inline void store(float4& v, int row, int col) {
*(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
}
__device__ inline void store(float2& v, int row, int col) {
*(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
}
__device__ inline void store(float& v, int row, int col) {
*(reinterpret_cast<float*>(ptr(data, row, col))) = v;
}
template <int N>
__device__ inline void store(T (&v)[N], int row, int col) {
if constexpr (sizeof(T) * N == 4) {
store(*(reinterpret_cast<float*>(&v[0])), row, col);
} else if constexpr (sizeof(T) * N == 8) {
store(*(reinterpret_cast<float2*>(&v[0])), row, col);
} else if constexpr (sizeof(T) * N == 16) {
store(*(reinterpret_cast<float4*>(&v[0])), row, col);
} else {
MLX_UNROLL
for (int i = 0; i < N; i++) {
*ptr(data, row, col + i) = v[i];
}
}
}
};
/**
* Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately.
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load(Tile& tile, const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
float4 tmp;
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
/**
* Copy 16 bytes from the globale memory address pointed to by x to the smem
* address pointed to by row_address.
*
* A simple wrapper over the PTX.
*/
template <typename T>
__device__ inline void cp_async_16(uint32_t row_address, const T* x) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x)));
}
/**
* Submit all the previous async copies to be executed.
*/
__device__ inline void cp_async_commit() {
asm volatile("cp.async.commit_group;\n" ::);
}
/**
* Wait for all the async copies to finish.
*/
__device__ inline void cp_async_wait_all() {
asm volatile("cp.async.wait_all;\n" ::);
}
/**
* The asynchronous equivalent of load.
*
* Loads the tile from global memory by submitting a bunch of async copy
* instructions. The copy won't start until commit is called and we don't have
* a guarantee it will finish until wait is called.
*
* It should be used as follows
*
* load(...)
* load(...)
* cp_async_commit()
* do_other_stuff()
* cp_async_wait_all()
* do_stuff_with_shmem()
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
}
}
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
Tile& tile,
uint32_t base_address,
const T* x,
int N,
int max_rows) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
x += row * N + col * ELEMENTS_PER_LOAD;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
if (row + i * STEP_ROWS < max_rows) {
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
} else {
float4 tmp = {0, 0, 0, 0};
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
}
} // namespace mlx::core::cu

View File

@@ -81,6 +81,7 @@ NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)

View File

@@ -2,17 +2,30 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -227,102 +240,145 @@ __global__ void affine_dequantize(
}
} // namespace cu
namespace {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate the number of elements per thread
int per_thread = group_size_ / WARP_SIZE;
size_t size = w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
enc.set_output_array(biases);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.size());
});
});
});
if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
packs_per_int = 8;
f(std::integral_constant<int, 5>{});
break;
case 6:
packs_per_int = 4;
f(std::integral_constant<int, 6>{});
break;
default:
packs_per_int = 8 / bits_;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
size_t size = w.size() / packs_per_int;
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread;
}
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
auto kernel =
cu::affine_dequantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<uint8_t>(),
inputs[1].data<DataType>(),
inputs[2].data<DataType>(),
out.data<DataType>(),
out.size());
} else {
auto kernel =
cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<DataType>(),
out.data<uint8_t>(),
outputs[1].data<DataType>(),
outputs[2].data<DataType>(),
w.size());
}
});
});
});

View File

@@ -1,228 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/matmul/mma.cuh"
#include "mlx/backend/cuda/matmul/tiles.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
namespace mlx::core {
namespace cu {
template <int NUM_WARPS, int group_size, int bits, typename T, typename Tile>
__device__ inline void load_quantized(
Tile& tile,
const uint8_t* x,
const T* scales,
const T* biases,
int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(uint32_t) * get_pack_factor<bits>();
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
constexpr int MASK = (1 << bits) - 1;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
const int Nx = N / get_pack_factor<bits>();
const int Ng = N / group_size;
x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor<bits>());
scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
T vs[ELEMENTS_PER_LOAD];
uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
T s = scales[i * STEP_ROWS * Ng];
T b = biases[i * STEP_ROWS * Ng];
MLX_UNROLL
for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
}
tile.store(vs, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
template <
typename T,
int BM,
int BN,
int BK,
int group_size,
int bits,
bool aligned_M>
__global__ void qmm_t(
const T* x,
const uint8_t* w,
const T* scales,
const T* biases,
T* y,
int M,
int N,
int K) {
constexpr int WARPS_M = 2;
constexpr int WARPS_N = 4;
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M;
constexpr int WARP_STEP_N = BN / WARPS_N;
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N;
const int wn = warpid % WARPS_N;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&xs)[1] = *(SharedTile<T, BM, BK>(*)[1])(&shmem[0]);
SharedTile<T, BN, BK>(&ws)[1] =
*(SharedTile<T, BN, BK>(*)[1])(&shmem[1 * sizeof(T) * BM * BK]);
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
const int max_rows = M - blockIdx.y * BM;
x += blockIdx.y * BM * K;
w += blockIdx.x * BN * K / get_pack_factor<bits>();
scales += blockIdx.x * BN * K / group_size;
biases += blockIdx.x * BN * K / group_size;
y += blockIdx.y * BM * N + blockIdx.x * BN;
C.fill(0);
int tic = 0;
uint32_t base_addr_xs[1], base_addr_ws[1];
base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
if (aligned_M || max_rows >= BM) {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global(y, N, offset_m, offset_n);
} else {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async_safe<NUM_WARPS>(
xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global_safe(y, N, offset_m, offset_n, max_rows);
}
}
} // namespace cu
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.dtype() != bfloat16) {
throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
}
if (!transpose_) {
throw std::invalid_argument(
"[qmm] Only transposed matmul is supported for now");
}
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
auto kernel =
cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
if (M % BM != 0) {
kernel = cu::
qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
}
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
enc.add_kernel_node(
kernel,
grid,
2 * 4 * 32,
1 * sizeof(DataType) * (BM * BK + BN * BK),
x.data<DataType>(),
w.data<uint8_t>(),
scales.data<DataType>(),
biases.data<DataType>(),
out.data<DataType>(),
M,
N,
K);
});
});
});
}
} // namespace mlx::core

View File

@@ -1,113 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
}
} // namespace
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
out.set_data(allocator::malloc(out.nbytes()));
// Make sure the last two dims of x and w, s, b are contiguous. This should
// be relaxed for x.
array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
int K = x.shape(-1);
int M = non_batched ? x.size() / K : x.shape(-2);
int N = out.shape(-1);
qmm(x,
w,
scales,
biases,
out,
transpose_,
group_size_,
bits_,
M,
N,
K,
enc,
s);
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core

View File

@@ -1,42 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -1,59 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core

View File

@@ -170,7 +170,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbitsc,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
@@ -181,7 +180,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbits,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,

View File

@@ -47,7 +47,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy = contiguous_copy_gpu(in, s);
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -120,7 +120,6 @@ void all_reduce(
kernel,
blocks,
threads,
0,
static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
@@ -147,7 +146,6 @@ void all_reduce(
kernel,
blocks,
threads,
0,
static_cast<T*>(indata),
out.data<U>(),
block_step,

View File

@@ -230,7 +230,7 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args);
kernel, grid, blocks, indata, out.data<U>(), args);
});
});
});

View File

@@ -41,8 +41,7 @@ void init_reduce(
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
encoder.add_kernel_node(
kernel, grid, block, 0, out.data<U>(), out.size());
encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
});
});
}

View File

@@ -269,7 +269,7 @@ void row_reduce_simple(
int size = plan.shape.back();
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
kernel, grid, block, indata, out.data<U>(), out.size(), size);
});
});
}
@@ -322,7 +322,7 @@ void row_reduce_looped(
});
encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
kernel, grid, block, indata, out.data<U>(), out.size(), args);
});
});
}

View File

@@ -206,7 +206,8 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -232,7 +233,6 @@ void RMSNorm::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
@@ -259,7 +259,9 @@ void RMSNormVJP::eval_gpu(
return x;
}
copied = true;
return contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();
@@ -328,7 +330,6 @@ void RMSNormVJP::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),

View File

@@ -325,7 +325,6 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -342,7 +341,6 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -362,7 +360,6 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -384,7 +381,6 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),

View File

@@ -379,7 +379,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
in = contiguous_copy_gpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in);
}
@@ -414,7 +416,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
in.data_size() / axis_size,
block_dim,
0,
in.data<T>(),
out.data<U>(),
axis_size);
@@ -444,7 +445,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim,
0,
in.data<T>(),
out.data<U>(),
axis_size,

View File

@@ -125,7 +125,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -152,7 +153,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);

View File

@@ -72,7 +72,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
in = array(trans.shape(), trans.dtype(), nullptr, {});
copy_gpu(trans, in, CopyType::General, s);
encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);

View File

@@ -133,7 +133,6 @@ void ternary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
@@ -152,7 +151,6 @@ void ternary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
@@ -182,7 +180,6 @@ void ternary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),

View File

@@ -142,7 +142,6 @@ void unary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
@@ -155,7 +154,6 @@ void unary_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size(),

View File

@@ -46,10 +46,4 @@ void copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
array contiguous_copy_gpu(const array& arr, const Stream& s) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -43,7 +43,4 @@ void copy_gpu_inplace(
// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
} // namespace mlx::core

View File

@@ -149,7 +149,8 @@ void explicit_gemm_conv_group_ND_gpu(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize
array wt_transpose = contiguous_copy_gpu(wt_view, s);
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
@@ -960,12 +961,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
auto wt = inputs[1];
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
copies.push_back(in);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
copies.push_back(wt);
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
copy_gpu(wt, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
wt = arr_copy;
}
// 3D conv

View File

@@ -25,7 +25,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -33,7 +33,8 @@ std::tuple<bool, int64_t, array> check_transpose(
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy = contiguous_copy_gpu(arr, s);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
@@ -42,7 +43,8 @@ std::tuple<bool, int64_t, array> check_transpose(
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -73,7 +75,8 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
}
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
}
@@ -1891,7 +1894,8 @@ void segmented_mm(
return std::make_tuple(false, x);
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy);
};

View File

@@ -40,7 +40,8 @@ void RMSNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -106,7 +107,9 @@ void RMSNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
@@ -238,7 +241,8 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -315,7 +319,8 @@ void LayerNormVJP::eval_gpu(
if (x.flags().row_contiguous) {
return {x, false};
}
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();

View File

@@ -20,7 +20,8 @@ namespace {
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
} else {
@@ -37,7 +38,8 @@ inline array ensure_row_contiguous_matrix(
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}

View File

@@ -989,7 +989,8 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// input for the axes with stride smaller than the minimum reduction
// stride.
if (plan.type == GeneralReduce) {
array in_copy = contiguous_copy_gpu(in, s);
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);

View File

@@ -398,7 +398,8 @@ void ScaledDotProductAttention::eval_gpu(
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy = contiguous_copy_gpu(arr, s);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {

View File

@@ -30,7 +30,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags());
}
} else {
in = contiguous_copy_gpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in);
}

View File

@@ -35,7 +35,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -4,7 +4,7 @@
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 26
#define MLX_VERSION_PATCH 5
#define MLX_VERSION_PATCH 3
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -893,22 +893,24 @@ class Muon(Optimizer):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, X, steps: int):
assert (
X.ndim == 2
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
def _zeropower_via_newtonschulz5(self, G, steps: int):
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
transpose_needed = X.shape[-2] > X.shape[-1]
X = G.astype(mx.bfloat16)
transpose_needed = G.shape[-2] > G.shape[-1]
if transpose_needed:
X = X.T
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
# Ensure spectral norm is at most 1
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
X = X / norm
# Perform the NS iterations
for _ in range(steps):
A = X @ X.T
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
B = b * A + c * (A @ A)
X = a * X + B @ X
if transpose_needed:
X = X.T
@@ -917,35 +919,56 @@ class Muon(Optimizer):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
# Apply weight decay
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
# Update momentum buffer
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
# Get effective gradient
if self.nesterov:
update = gradient * (1 - self.momentum) + v * self.momentum
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
else:
update = v
effective_grad = v
lr = self.learning_rate.astype(gradient.dtype)
if update.ndim >= 2:
original_shape = update.shape
reshape_needed = update.ndim > 2
# For tensors with fewer than 2 dimensions, skip Newton-Schulz
if effective_grad.ndim < 2:
orthogonalized_grad = effective_grad
scale_factor = 1.0
else:
# Save original shape for 4D conv filters
original_shape = effective_grad.shape
reshape_needed = effective_grad.ndim > 2
if reshape_needed:
update = mx.reshape(update, (update.shape[0], -1))
effective_grad = mx.reshape(
effective_grad, (effective_grad.shape[0], -1)
)
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
# Apply Newton-Schulz orthogonalization
orthogonalized_grad = self._zeropower_via_newtonschulz5(
effective_grad, steps=self.ns_steps
)
# Reshape back if needed
if reshape_needed:
update = mx.reshape(update, original_shape)
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
# Calculate scaling factor
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
scale_factor = (
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
)
return parameter - lr * update
return (
parameter
- self.learning_rate.astype(gradient.dtype)
* orthogonalized_grad
* scale_factor
)
def clip_grad_norm(grads, max_norm):

View File

@@ -1,7 +1,7 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--plat manylinux_2_39_x86_64 \
--exclude libcublas* \
--exclude libnvrtc* \
-w wheel_tmp

View File

@@ -691,21 +691,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c

View File

@@ -39,14 +39,6 @@ target_sources(
linalg_tests.cpp
${METAL_TEST_SOURCES})
if(MLX_BUILD_CUDA)
# Find the CCCL headers in install dir.
target_compile_definitions(
mlx
PRIVATE
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
endif()
target_link_libraries(tests PRIVATE mlx doctest)
doctest_discover_tests(tests)
add_test(NAME tests COMMAND tests)