mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
15 Commits
b2273733ea
...
qmm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8269c9d02d | ||
|
|
903b40627c | ||
|
|
700f7dcf01 | ||
|
|
6c60bd1cbf | ||
|
|
a64cc02a0c | ||
|
|
346ae5fdb5 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 | ||
|
|
84b4d96efa | ||
|
|
aec67f2fa6 | ||
|
|
deee214a95 | ||
|
|
45adec102c | ||
|
|
31fc530c76 | ||
|
|
fbb3f65a1a | ||
|
|
6b1b8ea91b |
@@ -272,6 +272,7 @@ jobs:
|
|||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
|
python setup.py clean --all
|
||||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||||
- when:
|
- when:
|
||||||
condition:
|
condition:
|
||||||
@@ -333,6 +334,7 @@ jobs:
|
|||||||
<< parameters.build_env >> pip install ".[dev]" -v
|
<< parameters.build_env >> pip install ".[dev]" -v
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
|
python setup.py clean --all
|
||||||
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||||
bash python/scripts/repair_linux.sh
|
bash python/scripts/repair_linux.sh
|
||||||
- when:
|
- when:
|
||||||
@@ -364,7 +366,7 @@ jobs:
|
|||||||
type: string
|
type: string
|
||||||
default: ""
|
default: ""
|
||||||
machine:
|
machine:
|
||||||
image: linux-cuda-12:default
|
image: linux-cuda-12:2024.11.1
|
||||||
resource_class: gpu.nvidia.small.gen2
|
resource_class: gpu.nvidia.small.gen2
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||||
|
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ project(
|
|||||||
|
|
||||||
# ----------------------------- Setup -----------------------------
|
# ----------------------------- Setup -----------------------------
|
||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 20)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||||
|
|||||||
@@ -19,3 +19,4 @@ Common Optimizers
|
|||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
MultiOptimizer
|
MultiOptimizer
|
||||||
|
Muon
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
|
|
||||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||||
@@ -90,7 +92,7 @@ target_compile_options(
|
|||||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||||
set(MLX_CUDA_ARCHITECTURES
|
set(MLX_CUDA_ARCHITECTURES
|
||||||
"70;80"
|
"80"
|
||||||
CACHE STRING "CUDA architectures")
|
CACHE STRING "CUDA architectures")
|
||||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||||
@@ -130,3 +132,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
|||||||
# Install CCCL headers for JIT.
|
# Install CCCL headers for JIT.
|
||||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||||
|
|
||||||
|
# Make Thunderkittens available
|
||||||
|
FetchContent_Declare(
|
||||||
|
kittens
|
||||||
|
GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
|
||||||
|
GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
|
||||||
|
GIT_SHALLOW TRUE)
|
||||||
|
FetchContent_MakeAvailable(kittens)
|
||||||
|
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")
|
||||||
|
|||||||
@@ -17,6 +17,52 @@ namespace cu {
|
|||||||
|
|
||||||
constexpr int page_size = 16384;
|
constexpr int page_size = 16384;
|
||||||
|
|
||||||
|
// Any allocations smaller than this will try to use the small pool
|
||||||
|
constexpr int small_block_size = 8;
|
||||||
|
|
||||||
|
// The small pool size in bytes. This should be a multiple of the host page
|
||||||
|
// size and small_block_size.
|
||||||
|
constexpr int small_pool_size = 4 * page_size;
|
||||||
|
|
||||||
|
SmallSizePool::SmallSizePool() {
|
||||||
|
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
|
||||||
|
end_ = reinterpret_cast<void*>(
|
||||||
|
reinterpret_cast<char*>(buffer_) + small_pool_size);
|
||||||
|
next_free_ = reinterpret_cast<Block*>(buffer_);
|
||||||
|
|
||||||
|
auto num_blocks = small_pool_size / small_block_size;
|
||||||
|
auto curr = next_free_;
|
||||||
|
for (size_t i = 0; i < num_blocks - 1; ++i) {
|
||||||
|
curr->next = reinterpret_cast<Block*>(
|
||||||
|
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
|
||||||
|
curr = curr->next;
|
||||||
|
}
|
||||||
|
curr->next = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallSizePool::~SmallSizePool() {
|
||||||
|
CHECK_CUDA_ERROR(cudaFree(buffer_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void* SmallSizePool::malloc() {
|
||||||
|
if (next_free_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
Block* b = next_free_;
|
||||||
|
next_free_ = next_free_->next;
|
||||||
|
return static_cast<void*>(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SmallSizePool::free(void* p) {
|
||||||
|
auto b = static_cast<Block*>(p);
|
||||||
|
b->next = next_free_;
|
||||||
|
next_free_ = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SmallSizePool::in_pool(void* p) {
|
||||||
|
return (p >= buffer_) && (p < end_);
|
||||||
|
}
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
page_size,
|
page_size,
|
||||||
@@ -36,7 +82,9 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
auto orig_size = size;
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
if (size < page_size) {
|
if (size <= small_block_size) {
|
||||||
|
size = 8;
|
||||||
|
} else if (size < page_size) {
|
||||||
size = next_power_of_2(size);
|
size = next_power_of_2(size);
|
||||||
} else {
|
} else {
|
||||||
size = page_size * ((size + page_size - 1) / page_size);
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
@@ -53,11 +101,19 @@ Buffer CudaAllocator::malloc(size_t size) {
|
|||||||
|
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
buf = new CudaBuffer{nullptr, size};
|
buf = new CudaBuffer{nullptr, size};
|
||||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
|
||||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
// Try the scalar pool first
|
||||||
throw std::runtime_error(fmt::format(
|
if (size <= small_block_size) {
|
||||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
buf->data = scalar_pool_.malloc();
|
||||||
}
|
}
|
||||||
|
if (!buf->data) {
|
||||||
|
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||||
|
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
lock.lock();
|
lock.lock();
|
||||||
}
|
}
|
||||||
active_memory_ += size;
|
active_memory_ += size;
|
||||||
@@ -116,7 +172,11 @@ void CudaAllocator::cuda_free(void* buf) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cudaFree(buf);
|
if (scalar_pool_.in_pool(buf)) {
|
||||||
|
scalar_pool_.free(buf);
|
||||||
|
} else {
|
||||||
|
cudaFree(buf);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t CudaAllocator::get_active_memory() const {
|
size_t CudaAllocator::get_active_memory() const {
|
||||||
|
|||||||
@@ -22,6 +22,28 @@ struct CudaBuffer {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SmallSizePool {
|
||||||
|
private:
|
||||||
|
struct Block {
|
||||||
|
Block* next;
|
||||||
|
};
|
||||||
|
|
||||||
|
void* buffer_{nullptr};
|
||||||
|
Block* next_free_{nullptr};
|
||||||
|
void* end_{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
SmallSizePool();
|
||||||
|
~SmallSizePool();
|
||||||
|
|
||||||
|
SmallSizePool(const SmallSizePool&) = delete;
|
||||||
|
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||||
|
|
||||||
|
void* malloc();
|
||||||
|
void free(void* p);
|
||||||
|
bool in_pool(void* p);
|
||||||
|
};
|
||||||
|
|
||||||
class CudaAllocator : public allocator::Allocator {
|
class CudaAllocator : public allocator::Allocator {
|
||||||
public:
|
public:
|
||||||
Buffer malloc(size_t size) override;
|
Buffer malloc(size_t size) override;
|
||||||
@@ -60,6 +82,7 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
BufferCache<CudaBuffer> buffer_cache_;
|
BufferCache<CudaBuffer> buffer_cache_;
|
||||||
size_t active_memory_{0};
|
size_t active_memory_{0};
|
||||||
size_t peak_memory_{0};
|
size_t peak_memory_{0};
|
||||||
|
SmallSizePool scalar_pool_;
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaAllocator& allocator();
|
CudaAllocator& allocator();
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<uint32_t>(),
|
out.data<uint32_t>(),
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -219,6 +219,7 @@ void binary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@@ -235,6 +236,7 @@ void binary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@@ -269,6 +271,7 @@ void binary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@@ -256,6 +257,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@@ -291,6 +293,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
|
|||||||
@@ -53,9 +53,10 @@ struct FusedKernelBuilder {
|
|||||||
|
|
||||||
// Build function signature.
|
// Build function signature.
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
os += "template <typename IdxT = uint32_t>\n";
|
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
} else {
|
} else {
|
||||||
os += "template <int NDIM, typename IdxT = uint32_t>\n";
|
os +=
|
||||||
|
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||||
}
|
}
|
||||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||||
for (size_t i = 0; i < params.size(); ++i) {
|
for (size_t i = 0; i < params.size(); ++i) {
|
||||||
@@ -67,12 +68,46 @@ struct FusedKernelBuilder {
|
|||||||
}
|
}
|
||||||
os += ") {\n";
|
os += ") {\n";
|
||||||
|
|
||||||
// Index.
|
// Index. For non contiguous kernels we create a separate index
|
||||||
|
// variable per variable otherwise everyone uses `index`.
|
||||||
os +=
|
os +=
|
||||||
" IdxT index = cg::this_grid().thread_rank();\n"
|
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
|
||||||
" if (index >= size) {\n"
|
" if (index >= size) {\n"
|
||||||
" return;\n"
|
" return;\n"
|
||||||
" }\n";
|
" }\n";
|
||||||
|
if (!contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " IdxT " + xname + "_idx = 0;\n";
|
||||||
|
}
|
||||||
|
os += " {\n";
|
||||||
|
os += " IdxT loc = index;\n";
|
||||||
|
os +=
|
||||||
|
" #pragma unroll\n"
|
||||||
|
" for (int i = NDIM - 1; i >= 0; i--) {\n";
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
|
||||||
|
"_strides[i]);\n";
|
||||||
|
}
|
||||||
|
os +=
|
||||||
|
" loc /= shape[i];\n"
|
||||||
|
" }\n"
|
||||||
|
" }\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work loop
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||||
|
|
||||||
// Read inputs.
|
// Read inputs.
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
@@ -89,12 +124,9 @@ struct FusedKernelBuilder {
|
|||||||
} else if (contiguous) {
|
} else if (contiguous) {
|
||||||
value = fmt::format("{}[index]", xname);
|
value = fmt::format("{}[index]", xname);
|
||||||
} else {
|
} else {
|
||||||
std::string index = fmt::format(
|
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||||
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
|
|
||||||
xname);
|
|
||||||
value = fmt::format("{}[{}]", xname, index);
|
|
||||||
}
|
}
|
||||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write tape.
|
// Write tape.
|
||||||
@@ -113,14 +145,30 @@ struct FusedKernelBuilder {
|
|||||||
}
|
}
|
||||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||||
}
|
}
|
||||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write output.
|
// Write output.
|
||||||
for (const auto& x : outputs) {
|
for (const auto& x : outputs) {
|
||||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// End of work loop
|
||||||
|
os +=
|
||||||
|
"\n"
|
||||||
|
" index++;\n";
|
||||||
|
if (!contiguous) {
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
const std::string& xname = namer.get_name(x);
|
||||||
|
if (is_scalar(x) || is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os += " }\n";
|
||||||
|
|
||||||
os += "}\n";
|
os += "}\n";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -156,15 +204,28 @@ void Compiled::eval_gpu(
|
|||||||
builder.build("_strided", false);
|
builder.build("_strided", false);
|
||||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||||
// Build kernel names.
|
// Build kernel names.
|
||||||
std::vector<std::string> kernel_names = {
|
std::vector<std::string> kernel_names;
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
|
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
|
||||||
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
|
|
||||||
};
|
|
||||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
|
||||||
kernel_names.push_back(fmt::format(
|
kernel_names.push_back(fmt::format(
|
||||||
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
|
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||||
kernel_names.push_back(
|
lib_name(),
|
||||||
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
work_per_thread));
|
||||||
|
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
i,
|
||||||
|
work_per_thread));
|
||||||
|
kernel_names.push_back(fmt::format(
|
||||||
|
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
|
||||||
|
lib_name(),
|
||||||
|
i,
|
||||||
|
work_per_thread));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||||
});
|
});
|
||||||
@@ -207,13 +268,21 @@ void Compiled::eval_gpu(
|
|||||||
args.append<uint32_t>(outputs[0].data_size());
|
args.append<uint32_t>(outputs[0].data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Choose work per thread
|
||||||
|
int work_per_thread = 4;
|
||||||
|
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||||
|
work_per_thread = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Launch kernel.
|
// Launch kernel.
|
||||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
kernel_name += fmt::format("_contiguous<{}>", index_type);
|
kernel_name +=
|
||||||
|
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||||
} else {
|
} else {
|
||||||
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
|
kernel_name += fmt::format(
|
||||||
|
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||||
}
|
}
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
for (const auto& in : inputs) {
|
for (const auto& in : inputs) {
|
||||||
@@ -224,8 +293,9 @@ void Compiled::eval_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
|
auto [num_blocks, block_dims] =
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||||
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ void copy_contiguous(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
out.data_size());
|
out.data_size());
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ void copy_general(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
data_size,
|
||||||
@@ -94,6 +95,7 @@ void copy_general(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
data_size,
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ void copy_general_dynamic(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@@ -99,6 +100,7 @@ void copy_general_dynamic(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ void copy_general_input(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@@ -85,6 +86,7 @@ void copy_general_input(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0));
|
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||||
}
|
}
|
||||||
@@ -216,12 +215,14 @@ void CommandEncoder::add_kernel_node(
|
|||||||
void* func,
|
void* func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
kernel_params.blockDim = block_dim;
|
kernel_params.blockDim = block_dim;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
cudaGraphNode_t node;
|
cudaGraphNode_t node;
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||||
@@ -232,6 +233,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
@@ -242,6 +244,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
kernel_params.blockDimY = block_dim.y;
|
kernel_params.blockDimY = block_dim.y;
|
||||||
kernel_params.blockDimZ = block_dim.z;
|
kernel_params.blockDimZ = block_dim.z;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
CUgraphNode node;
|
CUgraphNode node;
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||||
|
|||||||
@@ -45,25 +45,34 @@ class CommandEncoder {
|
|||||||
void set_output_array(const array& arr);
|
void set_output_array(const array& arr);
|
||||||
|
|
||||||
template <typename F, typename... Params>
|
template <typename F, typename... Params>
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
|
F* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
Params&&... params) {
|
||||||
constexpr size_t num = sizeof...(Params);
|
constexpr size_t num = sizeof...(Params);
|
||||||
void* ptrs[num];
|
void* ptrs[num];
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||||
std::forward<Params>(params)),
|
std::forward<Params>(params)),
|
||||||
...);
|
...);
|
||||||
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
|
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_kernel_node(
|
void add_kernel_node(
|
||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
void* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
void** params);
|
||||||
|
|
||||||
void add_temporary(const array& arr) {
|
void add_temporary(const array& arr) {
|
||||||
temporaries_.push_back(arr.data_shared_ptr());
|
temporaries_.push_back(arr.data_shared_ptr());
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
|
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -52,13 +52,29 @@ const std::string& cuda_home() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the location of CCCL headers shipped with the distribution.
|
// Return the location of CCCL headers shipped with the distribution.
|
||||||
bool get_cccl_include(std::string* out) {
|
const std::string& cccl_dir() {
|
||||||
auto cccl_headers = current_binary_dir().parent_path() / "include" / "cccl";
|
static std::string dir = []() {
|
||||||
if (!std::filesystem::exists(cccl_headers)) {
|
std::filesystem::path path;
|
||||||
return false;
|
#if defined(MLX_CCCL_DIR)
|
||||||
}
|
// First search the install dir if defined.
|
||||||
*out = fmt::format("--include-path={}", cccl_headers.string());
|
path = MLX_CCCL_DIR;
|
||||||
return true;
|
if (std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
// Then search dynamically from the dir of libmlx.so file.
|
||||||
|
path = current_binary_dir().parent_path() / "include" / "cccl";
|
||||||
|
if (std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
|
// Finally check the environment variable.
|
||||||
|
path = std::getenv("MLX_CCCL_DIR");
|
||||||
|
if (!path.empty() && std::filesystem::exists(path)) {
|
||||||
|
return path.string();
|
||||||
|
}
|
||||||
|
return std::string();
|
||||||
|
}();
|
||||||
|
return dir;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the cache directory for storing compiled results.
|
// Get the cache directory for storing compiled results.
|
||||||
@@ -121,7 +137,8 @@ void write_cached_ptx(
|
|||||||
const std::filesystem::path& cache_dir,
|
const std::filesystem::path& cache_dir,
|
||||||
const std::string& module_name,
|
const std::string& module_name,
|
||||||
const std::vector<char>& ptx,
|
const std::vector<char>& ptx,
|
||||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||||
|
const std::string& source_code) {
|
||||||
if (cache_dir.empty()) {
|
if (cache_dir.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -134,6 +151,9 @@ void write_cached_ptx(
|
|||||||
for (const auto& [name, mangled] : ptx_kernels) {
|
for (const auto& [name, mangled] : ptx_kernels) {
|
||||||
txt_file << name << "\t" << mangled << std::endl;
|
txt_file << name << "\t" << mangled << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||||
|
source_file << source_code;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return if |device|'s version is not newer than |major|.|minor| version.
|
// Return if |device|'s version is not newer than |major|.|minor| version.
|
||||||
@@ -234,8 +254,9 @@ JitModule::JitModule(
|
|||||||
device.compute_capability_major(),
|
device.compute_capability_major(),
|
||||||
device.compute_capability_minor());
|
device.compute_capability_minor());
|
||||||
args.push_back(compute.c_str());
|
args.push_back(compute.c_str());
|
||||||
std::string cccl_include;
|
std::string cccl_include = cccl_dir();
|
||||||
if (get_cccl_include(&cccl_include)) {
|
if (!cccl_include.empty()) {
|
||||||
|
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||||
args.push_back(cccl_include.c_str());
|
args.push_back(cccl_include.c_str());
|
||||||
}
|
}
|
||||||
std::string cuda_include =
|
std::string cuda_include =
|
||||||
@@ -272,7 +293,8 @@ JitModule::JitModule(
|
|||||||
} else {
|
} else {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||||
}
|
}
|
||||||
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
write_cached_ptx(
|
||||||
|
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load module.
|
// Load module.
|
||||||
|
|||||||
@@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@@ -267,6 +266,7 @@ void LayerNorm::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
b.data<DataType>(),
|
b.data<DataType>(),
|
||||||
@@ -295,9 +295,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
copied = true;
|
copied = true;
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
return contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return x_copy;
|
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
bool donate_g = inputs[3].is_donatable();
|
bool donate_g = inputs[3].is_donatable();
|
||||||
@@ -381,6 +379,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
g.data<DataType>(),
|
g.data<DataType>(),
|
||||||
|
|||||||
@@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(x_copy);
|
encoder.add_temporary(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@@ -152,6 +151,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<DataType>(),
|
in.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
|
|||||||
@@ -27,6 +27,35 @@ void check_cublas_error(const char* name, cublasStatus_t err) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct CublasPreference {
|
||||||
|
CublasPreference(Device& device) {
|
||||||
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
|
// for Hopper+:
|
||||||
|
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||||
|
uint64_t MiB = 1024 * 1024;
|
||||||
|
uint64_t workspace_size =
|
||||||
|
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||||
|
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||||
|
pref_,
|
||||||
|
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||||
|
&workspace_size,
|
||||||
|
sizeof(uint64_t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
~CublasPreference() {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
|
||||||
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||||
|
static CublasPreference pref(device);
|
||||||
|
return pref.pref_;
|
||||||
|
}
|
||||||
|
|
||||||
class MatMul {
|
class MatMul {
|
||||||
public:
|
public:
|
||||||
MatMul(
|
MatMul(
|
||||||
@@ -43,7 +72,7 @@ class MatMul {
|
|||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride)
|
int64_t b_batch_stride)
|
||||||
: handle_(device.lt_handle()) {
|
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
|
||||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||||
|
|
||||||
auto scale_type = dtype_to_cuda_type(dtype);
|
auto scale_type = dtype_to_cuda_type(dtype);
|
||||||
@@ -77,20 +106,6 @@ class MatMul {
|
|||||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||||
out_desc_ = create_matrix_layout(
|
out_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
|
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
|
||||||
// for Hopper+:
|
|
||||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
|
||||||
uint64_t MiB = 1024 * 1024;
|
|
||||||
uint64_t workspace_size =
|
|
||||||
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
|
||||||
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
|
||||||
pref_,
|
|
||||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
|
||||||
&workspace_size,
|
|
||||||
sizeof(uint64_t)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MatMul(
|
MatMul(
|
||||||
@@ -104,7 +119,6 @@ class MatMul {
|
|||||||
uint64_t b_rows,
|
uint64_t b_rows,
|
||||||
uint64_t b_cols,
|
uint64_t b_cols,
|
||||||
int64_t ldb,
|
int64_t ldb,
|
||||||
bool c_transposed,
|
|
||||||
int64_t ldc,
|
int64_t ldc,
|
||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
@@ -126,15 +140,15 @@ class MatMul {
|
|||||||
b_batch_stride) {
|
b_batch_stride) {
|
||||||
auto type = dtype_to_cuda_type(dtype);
|
auto type = dtype_to_cuda_type(dtype);
|
||||||
c_desc_ = create_matrix_layout(
|
c_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
~MatMul() {
|
~MatMul() {
|
||||||
cublasLtMatrixLayoutDestroy(a_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(b_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(c_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||||
cublasLtMatrixLayoutDestroy(out_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||||
cublasLtMatmulDescDestroy(matmul_desc_);
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
@@ -259,9 +273,9 @@ class MatMul {
|
|||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
cublasLtHandle_t handle_{nullptr};
|
cublasLtHandle_t handle_{nullptr};
|
||||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||||
cublasLtMatmulPreference_t pref_{nullptr};
|
|
||||||
cublasLtMatrixLayout_t a_desc_{nullptr};
|
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||||
cublasLtMatrixLayout_t b_desc_{nullptr};
|
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||||
cublasLtMatrixLayout_t c_desc_{nullptr};
|
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||||
@@ -282,8 +296,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
|||||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||||
return std::make_tuple(true, sty, arr);
|
return std::make_tuple(true, sty, arr);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
enc.add_temporary(arr_copy);
|
enc.add_temporary(arr_copy);
|
||||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||||
}
|
}
|
||||||
@@ -389,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 3);
|
assert(inputs.size() == 3);
|
||||||
auto& a_pre = inputs[0];
|
auto& a_pre = inputs[0];
|
||||||
auto& b_pre = inputs[1];
|
auto& b_pre = inputs[1];
|
||||||
auto& c_pre = inputs[2];
|
auto c = inputs[2];
|
||||||
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Init checks and prep
|
// Init checks and prep
|
||||||
@@ -404,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// the arrays
|
// the arrays
|
||||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||||
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
|
|
||||||
|
int64_t ldc;
|
||||||
|
{
|
||||||
|
auto stx = c.strides()[c.ndim() - 2];
|
||||||
|
auto sty = c.strides()[c.ndim() - 1];
|
||||||
|
if (sty == 1 && stx == c.shape(-1)) {
|
||||||
|
ldc = stx;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
} else if (sty == 1 && stx == 0) {
|
||||||
|
ldc = 0;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
} else {
|
||||||
|
// Copy C into out and set C to out
|
||||||
|
ldc = c.shape(-1);
|
||||||
|
copy_gpu(c, out, CopyType::General, s);
|
||||||
|
c = out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Check and collapse batch dimensions
|
// Check and collapse batch dimensions
|
||||||
@@ -442,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
K,
|
K,
|
||||||
N,
|
N,
|
||||||
ldb,
|
ldb,
|
||||||
c_transposed,
|
|
||||||
ldc,
|
ldc,
|
||||||
batch_shape.back(),
|
batch_shape.back(),
|
||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
|
|||||||
108
mlx/backend/cuda/matmul/mma.cuh
Normal file
108
mlx/backend/cuda/matmul/mma.cuh
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/matmul/tiles.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
template <typename U, typename T>
|
||||||
|
__device__ inline void
|
||||||
|
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
|
||||||
|
* float tile.
|
||||||
|
*
|
||||||
|
* We actually perform C += A @ B.T
|
||||||
|
*/
|
||||||
|
__device__ inline void mma_t(
|
||||||
|
Tile16x16<float>& C,
|
||||||
|
Tile16x16<__nv_bfloat16>& A,
|
||||||
|
Tile16x16<__nv_bfloat16>& B) {
|
||||||
|
asm volatile(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||||
|
"{%0, %1, %2, %3}, "
|
||||||
|
"{%4, %5, %6, %7}, "
|
||||||
|
"{%8, %9}, "
|
||||||
|
"{%10, %11, %12, %13};"
|
||||||
|
|
||||||
|
// D matrix
|
||||||
|
: "+f"(C.values[0].x),
|
||||||
|
"+f"(C.values[0].y),
|
||||||
|
"+f"(C.values[1].x),
|
||||||
|
"+f"(C.values[1].y)
|
||||||
|
|
||||||
|
// A matrix
|
||||||
|
: "r"(*(uint32_t*)(&A.values[0])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[1])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[2])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[3])),
|
||||||
|
|
||||||
|
// B matrix
|
||||||
|
"r"(*(uint32_t*)(&B.values[0])),
|
||||||
|
"r"(*(uint32_t*)(&B.values[2])),
|
||||||
|
|
||||||
|
// C matrix
|
||||||
|
"f"(C.values[0].x),
|
||||||
|
"f"(C.values[0].y),
|
||||||
|
"f"(C.values[1].x),
|
||||||
|
"f"(C.values[1].y));
|
||||||
|
asm volatile(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||||
|
"{%0, %1, %2, %3}, "
|
||||||
|
"{%4, %5, %6, %7}, "
|
||||||
|
"{%8, %9}, "
|
||||||
|
"{%10, %11, %12, %13};"
|
||||||
|
|
||||||
|
// D matrix
|
||||||
|
: "+f"(C.values[2].x),
|
||||||
|
"+f"(C.values[2].y),
|
||||||
|
"+f"(C.values[3].x),
|
||||||
|
"+f"(C.values[3].y)
|
||||||
|
|
||||||
|
// A matrix
|
||||||
|
: "r"(*(uint32_t*)(&A.values[0])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[1])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[2])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[3])),
|
||||||
|
|
||||||
|
// B matrix
|
||||||
|
"r"(*(uint32_t*)(&B.values[1])),
|
||||||
|
"r"(*(uint32_t*)(&B.values[3])),
|
||||||
|
|
||||||
|
// C matrix
|
||||||
|
"f"(C.values[2].x),
|
||||||
|
"f"(C.values[2].y),
|
||||||
|
"f"(C.values[3].x),
|
||||||
|
"f"(C.values[3].y));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiply larger register tiles by delegating to mma_t.
|
||||||
|
*/
|
||||||
|
template <typename U, typename T, int M, int N, int K>
|
||||||
|
__device__ inline void mma_t(
|
||||||
|
RegisterTile<U, M, N>& C,
|
||||||
|
RegisterTile<T, M, K>& A,
|
||||||
|
RegisterTile<T, N, K>& B) {
|
||||||
|
constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
|
||||||
|
constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
|
||||||
|
constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < TILES_K; k++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int m = 0; m < TILES_M; m++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int n = 0; n < TILES_N; n++) {
|
||||||
|
mma_t(
|
||||||
|
C.data[m * TILES_N + n],
|
||||||
|
A.data[m * TILES_K + k],
|
||||||
|
B.data[n * TILES_K + k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
419
mlx/backend/cuda/matmul/tiles.cuh
Normal file
419
mlx/backend/cuda/matmul/tiles.cuh
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#define MLX_UNROLL _Pragma("unroll")
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// Map types to their vector of 2 type float -> float2, double -> double2 etc
|
||||||
|
template <typename T>
|
||||||
|
struct Vector2;
|
||||||
|
template <>
|
||||||
|
struct Vector2<double> {
|
||||||
|
using type = double2;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct Vector2<float> {
|
||||||
|
using type = float2;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct Vector2<__half> {
|
||||||
|
using type = __half2;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct Vector2<__nv_bfloat16> {
|
||||||
|
using type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
using Vector2_t = typename Vector2<T>::type;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The basic building block for Ampere mmas. A 16x16 tile distributed across
|
||||||
|
* the warp.
|
||||||
|
*
|
||||||
|
* Each thread holds 8 values. They are distributed according to
|
||||||
|
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
|
||||||
|
*
|
||||||
|
* For use instructions see the individual methods eg load().
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
struct Tile16x16 {
|
||||||
|
using T2 = Vector2_t<T>;
|
||||||
|
|
||||||
|
T2 values[4];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
T2 v2 = {v, v};
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
values[i] = v2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a 16x16 tile from shared memory.
|
||||||
|
*
|
||||||
|
* The instruction is a bit weird in the sense that the address provided by
|
||||||
|
* each thread and the elements loaded are not the same.
|
||||||
|
*
|
||||||
|
* We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
|
||||||
|
* result the warp provides 4*8 = 32 addresses one per row.
|
||||||
|
*
|
||||||
|
* Threads 0-7 provide the addresses for the first tile, 8-15 for the second
|
||||||
|
* and so on. For instance to load a non swizzled tile we would do
|
||||||
|
*
|
||||||
|
* base_addr + (laneid % 16) * BK + (laneid / 2) * 8
|
||||||
|
*
|
||||||
|
* See
|
||||||
|
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
|
||||||
|
*/
|
||||||
|
__device__ inline void load(uint32_t row_address) {
|
||||||
|
if constexpr (
|
||||||
|
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
|
||||||
|
asm volatile(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(*(uint32_t*)&(values[0])),
|
||||||
|
"=r"(*(uint32_t*)&(values[1])),
|
||||||
|
"=r"(*(uint32_t*)&(values[2])),
|
||||||
|
"=r"(*(uint32_t*)&(values[3]))
|
||||||
|
: "r"(row_address));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Store the tile to the address pointed to by `x`.
|
||||||
|
*
|
||||||
|
* The provided pointer is a generic pointer but this is meant to be used to
|
||||||
|
* store to global memory. For storing to shared memory we should use
|
||||||
|
* `stmatrix`.
|
||||||
|
*
|
||||||
|
* This also showcases the format of the tile quite nicely. Each register is
|
||||||
|
* holding to adjacent values. The indices are
|
||||||
|
*
|
||||||
|
* row + 0, col + 0
|
||||||
|
* row + 8, col + 0
|
||||||
|
* row + 0, col + 8
|
||||||
|
* row + 8, col + 8
|
||||||
|
*
|
||||||
|
* Given that we are dealing with Vector2_t<U> the column offsets are 4
|
||||||
|
* instead of 8.
|
||||||
|
*/
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N) {
|
||||||
|
using U2 = Vector2_t<U>;
|
||||||
|
U2* x2 = reinterpret_cast<U2*>(x);
|
||||||
|
const int laneid = threadIdx.x % 32;
|
||||||
|
const int row = laneid / 4;
|
||||||
|
const int col = laneid % 4;
|
||||||
|
if constexpr (std::is_same_v<U2, T2>) {
|
||||||
|
x2[(row + 0) * (N / 2) + col + 0] = values[0];
|
||||||
|
x2[(row + 0) * (N / 2) + col + 4] = values[2];
|
||||||
|
x2[(row + 8) * (N / 2) + col + 0] = values[1];
|
||||||
|
x2[(row + 8) * (N / 2) + col + 4] = values[3];
|
||||||
|
} else if constexpr (
|
||||||
|
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
|
||||||
|
x2[(row + 0) * (N / 2) + col + 0] =
|
||||||
|
__floats2bfloat162_rn(values[0].x, values[0].y);
|
||||||
|
x2[(row + 0) * (N / 2) + col + 4] =
|
||||||
|
__floats2bfloat162_rn(values[2].x, values[2].y);
|
||||||
|
x2[(row + 8) * (N / 2) + col + 0] =
|
||||||
|
__floats2bfloat162_rn(values[1].x, values[1].y);
|
||||||
|
x2[(row + 8) * (N / 2) + col + 4] =
|
||||||
|
__floats2bfloat162_rn(values[3].x, values[3].y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
|
||||||
|
const int laneid = threadIdx.x % 32;
|
||||||
|
const int row = laneid / 4;
|
||||||
|
const int col = laneid % 4;
|
||||||
|
if (row < max_rows) {
|
||||||
|
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
|
||||||
|
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
|
||||||
|
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
|
||||||
|
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
|
||||||
|
}
|
||||||
|
if (row + 8 < max_rows) {
|
||||||
|
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
|
||||||
|
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
|
||||||
|
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
|
||||||
|
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple container of multiple Tile16x16.
|
||||||
|
*
|
||||||
|
* Provides utility functions for loading and manipulating collections of basic
|
||||||
|
* tiles.
|
||||||
|
*/
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct RegisterTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
|
||||||
|
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].fill(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tile>
|
||||||
|
__device__ inline void
|
||||||
|
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].load(
|
||||||
|
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void
|
||||||
|
store_global_safe(U* x, int N, int row, int col, int max_rows) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global_safe(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct SharedTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
static constexpr int NUMEL = ROWS * COLS;
|
||||||
|
|
||||||
|
// Swizzle taken from ThunderKittens.
|
||||||
|
//
|
||||||
|
// See inludes/types/shared/st.cuh
|
||||||
|
//
|
||||||
|
// I do feel that it is too math heavy and can be improved. Also the math is
|
||||||
|
// done every time although the addresses don't change from load to load. I
|
||||||
|
// guess we are expecting the compiler to figure that out.
|
||||||
|
static constexpr int swizzle_bytes =
|
||||||
|
(sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
|
||||||
|
: (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
|
||||||
|
|
||||||
|
T data[ROWS * COLS];
|
||||||
|
|
||||||
|
// Return a pointer to the element at (row, col) using the swizzle.
|
||||||
|
__device__ static inline T* ptr(T* ptr, int row, int col) {
|
||||||
|
if constexpr (swizzle_bytes > 0) {
|
||||||
|
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||||
|
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||||
|
const int outer_idx = col / subtile_cols;
|
||||||
|
const uint64_t addr =
|
||||||
|
(uint64_t)(&ptr
|
||||||
|
[outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||||
|
col % subtile_cols]);
|
||||||
|
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||||
|
return (T*)(addr ^ swizzle);
|
||||||
|
} else {
|
||||||
|
return ptr + row * COLS + col;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the location of the element at (row, col) using the swizzle.
|
||||||
|
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
||||||
|
if constexpr (swizzle_bytes > 0) {
|
||||||
|
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||||
|
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||||
|
const int outer_idx = col / subtile_cols;
|
||||||
|
const uint32_t addr = ptr +
|
||||||
|
sizeof(T) *
|
||||||
|
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||||
|
col % subtile_cols);
|
||||||
|
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||||
|
return (addr ^ swizzle);
|
||||||
|
} else {
|
||||||
|
return ptr + sizeof(T) * (row * COLS + col);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience functions to edit elements going through the swizzle.
|
||||||
|
__device__ inline T& operator()(int row, int col) {
|
||||||
|
return *ptr(data, row, col);
|
||||||
|
}
|
||||||
|
__device__ inline void store(float4& v, int row, int col) {
|
||||||
|
*(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
|
||||||
|
}
|
||||||
|
__device__ inline void store(float2& v, int row, int col) {
|
||||||
|
*(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
|
||||||
|
}
|
||||||
|
__device__ inline void store(float& v, int row, int col) {
|
||||||
|
*(reinterpret_cast<float*>(ptr(data, row, col))) = v;
|
||||||
|
}
|
||||||
|
template <int N>
|
||||||
|
__device__ inline void store(T (&v)[N], int row, int col) {
|
||||||
|
if constexpr (sizeof(T) * N == 4) {
|
||||||
|
store(*(reinterpret_cast<float*>(&v[0])), row, col);
|
||||||
|
} else if constexpr (sizeof(T) * N == 8) {
|
||||||
|
store(*(reinterpret_cast<float2*>(&v[0])), row, col);
|
||||||
|
} else if constexpr (sizeof(T) * N == 16) {
|
||||||
|
store(*(reinterpret_cast<float4*>(&v[0])), row, col);
|
||||||
|
} else {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
*ptr(data, row, col + i) = v[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load the tile from global memory by loading 16 bytes at a time and storing
|
||||||
|
* them immediately.
|
||||||
|
*/
|
||||||
|
template <int NUM_WARPS, typename T, typename Tile>
|
||||||
|
__device__ inline void load(Tile& tile, const T* x, int N) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
float4 tmp;
|
||||||
|
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
|
||||||
|
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy 16 bytes from the globale memory address pointed to by x to the smem
|
||||||
|
* address pointed to by row_address.
|
||||||
|
*
|
||||||
|
* A simple wrapper over the PTX.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
__device__ inline void cp_async_16(uint32_t row_address, const T* x) {
|
||||||
|
asm volatile(
|
||||||
|
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||||
|
"l"(reinterpret_cast<const int4*>(x)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Submit all the previous async copies to be executed.
|
||||||
|
*/
|
||||||
|
__device__ inline void cp_async_commit() {
|
||||||
|
asm volatile("cp.async.commit_group;\n" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wait for all the async copies to finish.
|
||||||
|
*/
|
||||||
|
__device__ inline void cp_async_wait_all() {
|
||||||
|
asm volatile("cp.async.wait_all;\n" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The asynchronous equivalent of load.
|
||||||
|
*
|
||||||
|
* Loads the tile from global memory by submitting a bunch of async copy
|
||||||
|
* instructions. The copy won't start until commit is called and we don't have
|
||||||
|
* a guarantee it will finish until wait is called.
|
||||||
|
*
|
||||||
|
* It should be used as follows
|
||||||
|
*
|
||||||
|
* load(...)
|
||||||
|
* load(...)
|
||||||
|
* cp_async_commit()
|
||||||
|
* do_other_stuff()
|
||||||
|
* cp_async_wait_all()
|
||||||
|
* do_stuff_with_shmem()
|
||||||
|
*/
|
||||||
|
template <int NUM_WARPS, typename T, typename Tile>
|
||||||
|
__device__ inline void
|
||||||
|
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
cp_async_16(
|
||||||
|
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
|
||||||
|
x + i * STEP_ROWS * N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NUM_WARPS, typename T, typename Tile>
|
||||||
|
__device__ inline void load_async_safe(
|
||||||
|
Tile& tile,
|
||||||
|
uint32_t base_address,
|
||||||
|
const T* x,
|
||||||
|
int N,
|
||||||
|
int max_rows) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
if (row + i * STEP_ROWS < max_rows) {
|
||||||
|
cp_async_16(
|
||||||
|
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
|
||||||
|
x + i * STEP_ROWS * N);
|
||||||
|
} else {
|
||||||
|
float4 tmp = {0, 0, 0, 0};
|
||||||
|
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
@@ -81,7 +81,6 @@ NO_GPU(Hadamard)
|
|||||||
NO_GPU(Load)
|
NO_GPU(Load)
|
||||||
NO_GPU_MULTI(LUF)
|
NO_GPU_MULTI(LUF)
|
||||||
NO_GPU_MULTI(QRF)
|
NO_GPU_MULTI(QRF)
|
||||||
NO_GPU(QuantizedMatmul)
|
|
||||||
NO_GPU(SegmentedMM)
|
NO_GPU(SegmentedMM)
|
||||||
NO_GPU_MULTI(SVD)
|
NO_GPU_MULTI(SVD)
|
||||||
NO_GPU(Inverse)
|
NO_GPU(Inverse)
|
||||||
|
|||||||
@@ -2,30 +2,17 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/fast_primitives.h"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
#include <nvtx3/nvtx3.hpp>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template <int bits, int wsize = 8>
|
|
||||||
inline constexpr __device__ short get_pack_factor() {
|
|
||||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int bits, int wsize = 8>
|
|
||||||
inline constexpr __device__ short get_bytes_per_pack() {
|
|
||||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
|
||||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
__global__ void
|
__global__ void
|
||||||
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
||||||
@@ -240,145 +227,102 @@ __global__ void affine_dequantize(
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
namespace {
|
|
||||||
|
|
||||||
inline array ensure_row_contiguous(
|
void affine_quantize(
|
||||||
const array& x,
|
const array& w,
|
||||||
|
array& wq,
|
||||||
|
array& scales,
|
||||||
|
array& biases,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
cu::CommandEncoder& enc,
|
cu::CommandEncoder& enc,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
// Calculate the number of elements per thread
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
int per_thread = group_size_ / WARP_SIZE;
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
size_t size = w.size() / per_thread;
|
||||||
enc.add_temporary(x_copy);
|
|
||||||
return x_copy;
|
|
||||||
} else {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void dispatch_groups(int group_size, F&& f) {
|
|
||||||
switch (group_size) {
|
|
||||||
case 32:
|
|
||||||
f(std::integral_constant<int, 32>{});
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
f(std::integral_constant<int, 64>{});
|
|
||||||
break;
|
|
||||||
case 128:
|
|
||||||
f(std::integral_constant<int, 128>{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void dispatch_bits(int bits, F&& f) {
|
|
||||||
switch (bits) {
|
|
||||||
case 2:
|
|
||||||
f(std::integral_constant<int, 2>{});
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
f(std::integral_constant<int, 3>{});
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
f(std::integral_constant<int, 4>{});
|
|
||||||
break;
|
|
||||||
case 5:
|
|
||||||
f(std::integral_constant<int, 5>{});
|
|
||||||
break;
|
|
||||||
case 6:
|
|
||||||
f(std::integral_constant<int, 6>{});
|
|
||||||
break;
|
|
||||||
case 8:
|
|
||||||
f(std::integral_constant<int, 8>{});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void fast::AffineQuantize::eval_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
std::vector<array>& outputs) {
|
|
||||||
auto& w_pre = inputs[0];
|
|
||||||
auto& out = outputs[0];
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
|
|
||||||
auto& s = stream();
|
|
||||||
auto& d = cu::device(s.device);
|
|
||||||
auto& enc = d.get_command_encoder(s);
|
|
||||||
|
|
||||||
auto w = ensure_row_contiguous(w_pre, enc, s);
|
|
||||||
enc.set_input_array(w);
|
|
||||||
if (dequantize_) {
|
|
||||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
|
||||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
|
||||||
enc.set_input_array(scales);
|
|
||||||
enc.set_input_array(biases);
|
|
||||||
enc.set_output_array(out);
|
|
||||||
} else {
|
|
||||||
auto& scales = outputs[1];
|
|
||||||
auto& biases = outputs[2];
|
|
||||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
|
||||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
|
||||||
enc.set_output_array(out);
|
|
||||||
enc.set_output_array(scales);
|
|
||||||
enc.set_output_array(biases);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
|
|
||||||
|
|
||||||
// Treat uint32 as uint8 in kernel
|
|
||||||
int uint8_per_uint32 = 4;
|
|
||||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
|
||||||
: bits_ == 6 ? 4
|
|
||||||
: 8 / bits_;
|
|
||||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
|
|
||||||
size_t size =
|
|
||||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
|
||||||
|
|
||||||
|
// Calculate the thread grid that we need to launch
|
||||||
bool large = size > UINT_MAX;
|
bool large = size > UINT_MAX;
|
||||||
auto grid_shape = w.shape();
|
auto grid_shape = w.shape();
|
||||||
|
grid_shape.back() /= per_thread;
|
||||||
|
|
||||||
if (dequantize_) {
|
enc.set_input_array(w);
|
||||||
grid_shape.back() *= uint8_per_uint32;
|
enc.set_output_array(wq);
|
||||||
} else {
|
enc.set_output_array(scales);
|
||||||
grid_shape.back() /= per_thread;
|
enc.set_output_array(biases);
|
||||||
}
|
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||||
|
|
||||||
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
|
|
||||||
dispatch_groups(group_size_, [&](auto group_size) {
|
dispatch_groups(group_size_, [&](auto group_size) {
|
||||||
dispatch_bits(bits_, [&](auto bits) {
|
dispatch_bits(bits_, [&](auto bits) {
|
||||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
if (dequantize_) {
|
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
|
||||||
auto kernel =
|
auto [num_blocks, block_dims] =
|
||||||
cu::affine_dequantize<DataType, group_size.value, bits.value>;
|
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
||||||
auto [num_blocks, block_dims] =
|
enc.add_kernel_node(
|
||||||
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
kernel,
|
||||||
enc.add_kernel_node(
|
num_blocks,
|
||||||
kernel,
|
block_dims,
|
||||||
num_blocks,
|
0,
|
||||||
block_dims,
|
w.data<T>(),
|
||||||
w.data<uint8_t>(),
|
wq.data<uint8_t>(),
|
||||||
inputs[1].data<DataType>(),
|
scales.data<T>(),
|
||||||
inputs[2].data<DataType>(),
|
biases.data<T>(),
|
||||||
out.data<DataType>(),
|
w.size());
|
||||||
out.size());
|
});
|
||||||
} else {
|
});
|
||||||
auto kernel =
|
});
|
||||||
cu::affine_quantize<DataType, group_size.value, bits.value>;
|
}
|
||||||
auto [num_blocks, block_dims] =
|
|
||||||
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
void affine_dequantize(
|
||||||
enc.add_kernel_node(
|
const array& wq,
|
||||||
kernel,
|
const array& scales,
|
||||||
num_blocks,
|
const array& biases,
|
||||||
block_dims,
|
array& w,
|
||||||
w.data<DataType>(),
|
int group_size_,
|
||||||
out.data<uint8_t>(),
|
int bits_,
|
||||||
outputs[1].data<DataType>(),
|
cu::CommandEncoder& enc,
|
||||||
outputs[2].data<DataType>(),
|
const Stream& s) {
|
||||||
w.size());
|
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
|
||||||
}
|
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
|
||||||
|
constexpr int uint8_per_uint32 = 4;
|
||||||
|
int packs_per_int;
|
||||||
|
switch (bits_) {
|
||||||
|
case 3:
|
||||||
|
case 5:
|
||||||
|
packs_per_int = 8;
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
packs_per_int = 4;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
packs_per_int = 8 / bits_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size = w.size() / packs_per_int;
|
||||||
|
bool large = size > UINT_MAX;
|
||||||
|
auto grid_shape = w.shape();
|
||||||
|
grid_shape.back() *= uint8_per_uint32;
|
||||||
|
|
||||||
|
enc.set_input_array(wq);
|
||||||
|
enc.set_input_array(scales);
|
||||||
|
enc.set_input_array(biases);
|
||||||
|
enc.set_output_array(w);
|
||||||
|
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||||
|
dispatch_groups(group_size_, [&](auto group_size) {
|
||||||
|
dispatch_bits(bits_, [&](auto bits) {
|
||||||
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
||||||
|
enc.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
wq.data<uint8_t>(),
|
||||||
|
scales.data<T>(),
|
||||||
|
biases.data<T>(),
|
||||||
|
w.data<T>(),
|
||||||
|
w.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
228
mlx/backend/cuda/quantized/qmm.cu
Normal file
228
mlx/backend/cuda/quantized/qmm.cu
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/matmul/mma.cuh"
|
||||||
|
#include "mlx/backend/cuda/matmul/tiles.cuh"
|
||||||
|
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
template <int NUM_WARPS, int group_size, int bits, typename T, typename Tile>
|
||||||
|
__device__ inline void load_quantized(
|
||||||
|
Tile& tile,
|
||||||
|
const uint8_t* x,
|
||||||
|
const T* scales,
|
||||||
|
const T* biases,
|
||||||
|
int N) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(uint32_t) * get_pack_factor<bits>();
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
constexpr int MASK = (1 << bits) - 1;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int Nx = N / get_pack_factor<bits>();
|
||||||
|
const int Ng = N / group_size;
|
||||||
|
|
||||||
|
x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor<bits>());
|
||||||
|
scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
|
||||||
|
biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
T vs[ELEMENTS_PER_LOAD];
|
||||||
|
uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
|
||||||
|
T s = scales[i * STEP_ROWS * Ng];
|
||||||
|
T b = biases[i * STEP_ROWS * Ng];
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
|
||||||
|
vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
|
||||||
|
}
|
||||||
|
tile.store(vs, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
bool aligned_M>
|
||||||
|
__global__ void qmm_t(
|
||||||
|
const T* x,
|
||||||
|
const uint8_t* w,
|
||||||
|
const T* scales,
|
||||||
|
const T* biases,
|
||||||
|
T* y,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K) {
|
||||||
|
constexpr int WARPS_M = 2;
|
||||||
|
constexpr int WARPS_N = 4;
|
||||||
|
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||||
|
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||||
|
constexpr int WARP_STEP_N = BN / WARPS_N;
|
||||||
|
|
||||||
|
const int warpid = threadIdx.x / 32;
|
||||||
|
const int laneid = threadIdx.x % 32;
|
||||||
|
const int wm = warpid / WARPS_N;
|
||||||
|
const int wn = warpid % WARPS_N;
|
||||||
|
const int offset_m = wm * WARP_STEP_M;
|
||||||
|
const int offset_n = wn * WARP_STEP_N;
|
||||||
|
|
||||||
|
extern __shared__ char shmem[];
|
||||||
|
SharedTile<T, BM, BK>(&xs)[1] = *(SharedTile<T, BM, BK>(*)[1])(&shmem[0]);
|
||||||
|
SharedTile<T, BN, BK>(&ws)[1] =
|
||||||
|
*(SharedTile<T, BN, BK>(*)[1])(&shmem[1 * sizeof(T) * BM * BK]);
|
||||||
|
|
||||||
|
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||||
|
RegisterTile<T, BM / WARPS_M, 16> A;
|
||||||
|
RegisterTile<T, BN / WARPS_N, 16> B;
|
||||||
|
|
||||||
|
const int max_rows = M - blockIdx.y * BM;
|
||||||
|
|
||||||
|
x += blockIdx.y * BM * K;
|
||||||
|
w += blockIdx.x * BN * K / get_pack_factor<bits>();
|
||||||
|
scales += blockIdx.x * BN * K / group_size;
|
||||||
|
biases += blockIdx.x * BN * K / group_size;
|
||||||
|
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||||
|
|
||||||
|
C.fill(0);
|
||||||
|
|
||||||
|
int tic = 0;
|
||||||
|
uint32_t base_addr_xs[1], base_addr_ws[1];
|
||||||
|
base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
|
||||||
|
base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
|
||||||
|
|
||||||
|
if (aligned_M || max_rows >= BM) {
|
||||||
|
for (int k_block = 0; k_block < K; k_block += BK) {
|
||||||
|
load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
|
||||||
|
cp_async_commit();
|
||||||
|
load_quantized<NUM_WARPS, group_size, bits>(
|
||||||
|
ws[tic],
|
||||||
|
w + k_block / get_pack_factor<bits>(),
|
||||||
|
scales + k_block / group_size,
|
||||||
|
biases + k_block / group_size,
|
||||||
|
K);
|
||||||
|
cp_async_wait_all();
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < BK / 16; k++) {
|
||||||
|
A.load(
|
||||||
|
xs[tic],
|
||||||
|
base_addr_xs[tic],
|
||||||
|
offset_m + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
B.load(
|
||||||
|
ws[tic],
|
||||||
|
base_addr_ws[tic],
|
||||||
|
offset_n + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
mma_t(C, A, B);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
C.store_global(y, N, offset_m, offset_n);
|
||||||
|
} else {
|
||||||
|
for (int k_block = 0; k_block < K; k_block += BK) {
|
||||||
|
load_async_safe<NUM_WARPS>(
|
||||||
|
xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
|
||||||
|
cp_async_commit();
|
||||||
|
load_quantized<NUM_WARPS, group_size, bits>(
|
||||||
|
ws[tic],
|
||||||
|
w + k_block / get_pack_factor<bits>(),
|
||||||
|
scales + k_block / group_size,
|
||||||
|
biases + k_block / group_size,
|
||||||
|
K);
|
||||||
|
cp_async_wait_all();
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < BK / 16; k++) {
|
||||||
|
A.load(
|
||||||
|
xs[tic],
|
||||||
|
base_addr_xs[tic],
|
||||||
|
offset_m + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
B.load(
|
||||||
|
ws[tic],
|
||||||
|
base_addr_ws[tic],
|
||||||
|
offset_n + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
mma_t(C, A, B);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
C.store_global_safe(y, N, offset_m, offset_n, max_rows);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void qmm(
|
||||||
|
const array& x,
|
||||||
|
const array& w,
|
||||||
|
const array& scales,
|
||||||
|
const array& biases,
|
||||||
|
array& out,
|
||||||
|
bool transpose_,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
|
if (x.dtype() != bfloat16) {
|
||||||
|
throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
|
||||||
|
}
|
||||||
|
if (!transpose_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[qmm] Only transposed matmul is supported for now");
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
|
||||||
|
dispatch_groups(group_size_, [&](auto group_size) {
|
||||||
|
dispatch_bits(bits_, [&](auto bits) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
|
||||||
|
constexpr int BM = 128;
|
||||||
|
constexpr int BN = 128;
|
||||||
|
constexpr int BK = 32;
|
||||||
|
auto kernel =
|
||||||
|
cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
|
||||||
|
if (M % BM != 0) {
|
||||||
|
kernel = cu::
|
||||||
|
qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
|
||||||
|
}
|
||||||
|
|
||||||
|
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
|
||||||
|
|
||||||
|
enc.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid,
|
||||||
|
2 * 4 * 32,
|
||||||
|
1 * sizeof(DataType) * (BM * BK + BN * BK),
|
||||||
|
x.data<DataType>(),
|
||||||
|
w.data<uint8_t>(),
|
||||||
|
scales.data<DataType>(),
|
||||||
|
biases.data<DataType>(),
|
||||||
|
out.data<DataType>(),
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
113
mlx/backend/cuda/quantized/quantized.cu
Normal file
113
mlx/backend/cuda/quantized/quantized.cu
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/quantized/quantized.cuh"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cooperative_groups/reduce.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
inline array ensure_row_contiguous(
|
||||||
|
const array& x,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
|
if (!x.flags().row_contiguous) {
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
|
enc.add_temporary(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
} else {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array ensure_row_contiguous_matrix(
|
||||||
|
const array& x,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s) {
|
||||||
|
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||||
|
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||||
|
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
|
enc.add_temporary(x_copy);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = cu::device(s.device);
|
||||||
|
auto& enc = d.get_command_encoder(s);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
// Make sure the last two dims of x and w, s, b are contiguous. This should
|
||||||
|
// be relaxed for x.
|
||||||
|
array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
|
||||||
|
array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
|
||||||
|
array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
|
||||||
|
array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
|
||||||
|
|
||||||
|
// Extract the matmul shapes
|
||||||
|
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
|
||||||
|
int K = x.shape(-1);
|
||||||
|
int M = non_batched ? x.size() / K : x.shape(-2);
|
||||||
|
int N = out.shape(-1);
|
||||||
|
|
||||||
|
qmm(x,
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
biases,
|
||||||
|
out,
|
||||||
|
transpose_,
|
||||||
|
group_size_,
|
||||||
|
bits_,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
enc,
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void fast::AffineQuantize::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = cu::device(s.device);
|
||||||
|
auto& enc = d.get_command_encoder(s);
|
||||||
|
|
||||||
|
if (dequantize_) {
|
||||||
|
auto wq = ensure_row_contiguous(inputs[0], enc, s);
|
||||||
|
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||||
|
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||||
|
auto& w = outputs[0];
|
||||||
|
|
||||||
|
w.set_data(allocator::malloc(w.nbytes()));
|
||||||
|
|
||||||
|
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||||
|
} else {
|
||||||
|
auto w = ensure_row_contiguous(inputs[0], enc, s);
|
||||||
|
auto& wq = outputs[0];
|
||||||
|
auto& scales = outputs[1];
|
||||||
|
auto& biases = outputs[2];
|
||||||
|
|
||||||
|
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||||
|
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||||
|
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||||
|
|
||||||
|
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
42
mlx/backend/cuda/quantized/quantized.cuh
Normal file
42
mlx/backend/cuda/quantized/quantized.cuh
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void affine_quantize(
|
||||||
|
const array& w,
|
||||||
|
array& wq,
|
||||||
|
array& scales,
|
||||||
|
array& biases,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
void affine_dequantize(
|
||||||
|
const array& wq,
|
||||||
|
const array& scales,
|
||||||
|
const array& biases,
|
||||||
|
array& w,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
void qmm(
|
||||||
|
const array& x,
|
||||||
|
const array& w,
|
||||||
|
const array& scales,
|
||||||
|
const array& biases,
|
||||||
|
array& out,
|
||||||
|
bool transpose_,
|
||||||
|
int group_size_,
|
||||||
|
int bits_,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
cu::CommandEncoder& enc,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
template <int bits, int wsize = 8>
|
||||||
|
inline constexpr __device__ short get_pack_factor() {
|
||||||
|
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int bits, int wsize = 8>
|
||||||
|
inline constexpr __device__ short get_bytes_per_pack() {
|
||||||
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
|
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_groups(int group_size, F&& f) {
|
||||||
|
switch (group_size) {
|
||||||
|
case 32:
|
||||||
|
f(std::integral_constant<int, 32>{});
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
f(std::integral_constant<int, 64>{});
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
f(std::integral_constant<int, 128>{});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_bits(int bits, F&& f) {
|
||||||
|
switch (bits) {
|
||||||
|
case 2:
|
||||||
|
f(std::integral_constant<int, 2>{});
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
f(std::integral_constant<int, 3>{});
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
f(std::integral_constant<int, 4>{});
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
f(std::integral_constant<int, 5>{});
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
f(std::integral_constant<int, 6>{});
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
f(std::integral_constant<int, 8>{});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
cu::rbitsc,
|
cu::rbitsc,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
keys.data<uint32_t>(),
|
keys.data<uint32_t>(),
|
||||||
out.data<uint8_t>(),
|
out.data<uint8_t>(),
|
||||||
grid_dims,
|
grid_dims,
|
||||||
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
cu::rbits,
|
cu::rbits,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
keys.data<uint32_t>(),
|
keys.data<uint32_t>(),
|
||||||
out.data<uint8_t>(),
|
out.data<uint8_t>(),
|
||||||
grid_dims,
|
grid_dims,
|
||||||
|
|||||||
@@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
||||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
array in_copy = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, in_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(in_copy);
|
encoder.add_temporary(in_copy);
|
||||||
in = in_copy;
|
in = in_copy;
|
||||||
plan = get_reduction_plan(in, axes_);
|
plan = get_reduction_plan(in, axes_);
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ void all_reduce(
|
|||||||
kernel,
|
kernel,
|
||||||
blocks,
|
blocks,
|
||||||
threads,
|
threads,
|
||||||
|
0,
|
||||||
static_cast<T*>(indata),
|
static_cast<T*>(indata),
|
||||||
intermediate.data<U>(),
|
intermediate.data<U>(),
|
||||||
block_step,
|
block_step,
|
||||||
@@ -146,6 +147,7 @@ void all_reduce(
|
|||||||
kernel,
|
kernel,
|
||||||
blocks,
|
blocks,
|
||||||
threads,
|
threads,
|
||||||
|
0,
|
||||||
static_cast<T*>(indata),
|
static_cast<T*>(indata),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
block_step,
|
block_step,
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ void col_reduce_looped(
|
|||||||
auto kernel =
|
auto kernel =
|
||||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, blocks, indata, out.data<U>(), args);
|
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ void init_reduce(
|
|||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||||
grid.x = (grid.x + 1023) / 1024;
|
grid.x = (grid.x + 1023) / 1024;
|
||||||
encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
|
encoder.add_kernel_node(
|
||||||
|
kernel, grid, block, 0, out.data<U>(), out.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -269,7 +269,7 @@ void row_reduce_simple(
|
|||||||
|
|
||||||
int size = plan.shape.back();
|
int size = plan.shape.back();
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, indata, out.data<U>(), out.size(), size);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -322,7 +322,7 @@ void row_reduce_looped(
|
|||||||
});
|
});
|
||||||
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, indata, out.data<U>(), out.size(), args);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@@ -233,6 +232,7 @@ void RMSNorm::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
@@ -259,9 +259,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
copied = true;
|
copied = true;
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
return contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return x_copy;
|
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
bool donate_g = inputs[2].is_donatable();
|
bool donate_g = inputs[2].is_donatable();
|
||||||
@@ -330,6 +328,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
g.data<DataType>(),
|
g.data<DataType>(),
|
||||||
|
|||||||
@@ -325,6 +325,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@@ -341,6 +342,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@@ -360,6 +362,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@@ -381,6 +384,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
|
|||||||
@@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
|
||||||
in = std::move(arr_copy);
|
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -416,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
in.data_size() / axis_size,
|
in.data_size() / axis_size,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
@@ -445,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
axis_size,
|
axis_size,
|
||||||
|
|||||||
@@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@@ -153,6 +152,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<DataType>(),
|
in.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
|
|||||||
@@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
||||||
if (!is_segmented_sort) {
|
if (!is_segmented_sort) {
|
||||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||||
in = array(trans.shape(), trans.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(trans, s);
|
||||||
copy_gpu(trans, in, CopyType::General, s);
|
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||||
encoder.add_temporary(out);
|
encoder.add_temporary(out);
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ void ternary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
@@ -151,6 +152,7 @@ void ternary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
@@ -180,6 +182,7 @@ void ternary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
|
|||||||
@@ -142,6 +142,7 @@ void unary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>(),
|
in.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size());
|
out.data_size());
|
||||||
@@ -154,6 +155,7 @@ void unary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>(),
|
in.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.data_size(),
|
||||||
|
|||||||
@@ -46,4 +46,10 @@ void copy_gpu_inplace(
|
|||||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -43,4 +43,7 @@ void copy_gpu_inplace(
|
|||||||
// Fill the output with the scalar val
|
// Fill the output with the scalar val
|
||||||
void fill_gpu(const array& val, array& out, const Stream& s);
|
void fill_gpu(const array& val, array& out, const Stream& s);
|
||||||
|
|
||||||
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
|
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
|||||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||||
|
|
||||||
// Materialize
|
// Materialize
|
||||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
array wt_transpose = contiguous_copy_gpu(wt_view, s);
|
||||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
|
||||||
|
|
||||||
// Perform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||||
@@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto in = inputs[0];
|
auto in = inputs[0];
|
||||||
auto wt = inputs[1];
|
auto wt = inputs[1];
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
copies.push_back(in);
|
||||||
copies.push_back(arr_copy);
|
|
||||||
in = arr_copy;
|
|
||||||
}
|
}
|
||||||
if (!wt.flags().row_contiguous) {
|
if (!wt.flags().row_contiguous) {
|
||||||
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
copy_gpu(wt, arr_copy, CopyType::General, s);
|
copies.push_back(wt);
|
||||||
copies.push_back(arr_copy);
|
|
||||||
wt = arr_copy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3D conv
|
// 3D conv
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
|||||||
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
||||||
return std::make_tuple(true, sty, arr);
|
return std::make_tuple(true, sty, arr);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
copies.push_back(arr_copy);
|
copies.push_back(arr_copy);
|
||||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||||
}
|
}
|
||||||
@@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
|||||||
inline array
|
inline array
|
||||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
if (!x.flags().row_contiguous) {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
} else {
|
} else {
|
||||||
@@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
||||||
}
|
}
|
||||||
@@ -1894,8 +1891,7 @@ void segmented_mm(
|
|||||||
return std::make_tuple(false, x);
|
return std::make_tuple(false, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return std::make_tuple(true, x_copy);
|
return std::make_tuple(true, x_copy);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
if (x.flags().row_contiguous) {
|
if (x.flags().row_contiguous) {
|
||||||
return {x, false};
|
return {x, false};
|
||||||
}
|
}
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return {x_copy, true};
|
return {x_copy, true};
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
@@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
if (x.flags().row_contiguous) {
|
if (x.flags().row_contiguous) {
|
||||||
return {x, false};
|
return {x, false};
|
||||||
}
|
}
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return {x_copy, true};
|
return {x_copy, true};
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ namespace {
|
|||||||
inline array
|
inline array
|
||||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
if (!x.flags().row_contiguous) {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
} else {
|
} else {
|
||||||
@@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
|
|||||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// input for the axes with stride smaller than the minimum reduction
|
// input for the axes with stride smaller than the minimum reduction
|
||||||
// stride.
|
// stride.
|
||||||
if (plan.type == GeneralReduce) {
|
if (plan.type == GeneralReduce) {
|
||||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
array in_copy = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, in_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(in_copy, s.index);
|
d.add_temporary(in_copy, s.index);
|
||||||
in = in_copy;
|
in = in_copy;
|
||||||
plan = get_reduction_plan(in, axes_);
|
plan = get_reduction_plan(in, axes_);
|
||||||
|
|||||||
@@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
auto copy_unless = [&copies, &s](
|
auto copy_unless = [&copies, &s](
|
||||||
auto predicate, const array& arr) -> const array& {
|
auto predicate, const array& arr) -> const array& {
|
||||||
if (!predicate(arr)) {
|
if (!predicate(arr)) {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
copies.push_back(std::move(arr_copy));
|
copies.push_back(std::move(arr_copy));
|
||||||
return copies.back();
|
return copies.back();
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
|
||||||
in = std::move(arr_copy);
|
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 26
|
#define MLX_VERSION_MINOR 26
|
||||||
#define MLX_VERSION_PATCH 3
|
#define MLX_VERSION_PATCH 5
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
|||||||
@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
|
|||||||
return parameter - update
|
return parameter - update
|
||||||
|
|
||||||
|
|
||||||
|
class Muon(Optimizer):
|
||||||
|
r"""The Muon optimizer.
|
||||||
|
|
||||||
|
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
|
||||||
|
original implementation: `Muon: An optimizer for hidden layers in neural
|
||||||
|
networks <https://kellerjordan.github.io/posts/muon/>`_
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Muon may be sub-optimal for the embedding layer, the final fully
|
||||||
|
connected layer, or any 0D/1D parameters. Those should be optimized
|
||||||
|
by a different method (e.g., :class:`AdamW`).
|
||||||
|
- For 4D convolutional filters, it works by flattening their last
|
||||||
|
dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate (float or callable): The learning rate.
|
||||||
|
momentum (float, optional): The momentum strength. Default: ``0.95``
|
||||||
|
weight_decay (float, optional): The weight decay (L2 penalty).
|
||||||
|
Default: ``0.01``
|
||||||
|
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
|
||||||
|
better performance. Default: ``True``
|
||||||
|
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
|
||||||
|
orthogonalization. Default: ``5``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||||
|
momentum: float = 0.95,
|
||||||
|
weight_decay: float = 0.01,
|
||||||
|
nesterov: bool = True,
|
||||||
|
ns_steps: int = 5,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._maybe_schedule("learning_rate", learning_rate)
|
||||||
|
self.momentum = momentum
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
self.nesterov = nesterov
|
||||||
|
self.ns_steps = ns_steps
|
||||||
|
|
||||||
|
def init_single(self, parameter: mx.array, state: dict):
|
||||||
|
"""Initialize optimizer state"""
|
||||||
|
state["v"] = mx.zeros_like(parameter)
|
||||||
|
|
||||||
|
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||||
|
assert (
|
||||||
|
X.ndim == 2
|
||||||
|
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||||
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
|
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||||
|
|
||||||
|
if transpose_needed:
|
||||||
|
X = X.T
|
||||||
|
|
||||||
|
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
|
||||||
|
|
||||||
|
for _ in range(steps):
|
||||||
|
A = X @ X.T
|
||||||
|
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||||
|
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||||
|
|
||||||
|
if transpose_needed:
|
||||||
|
X = X.T
|
||||||
|
return X
|
||||||
|
|
||||||
|
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||||
|
"""Performs the Muon parameter update"""
|
||||||
|
|
||||||
|
if self.weight_decay != 0:
|
||||||
|
gradient = gradient + self.weight_decay * parameter
|
||||||
|
|
||||||
|
v = self.momentum * state["v"]
|
||||||
|
v = v + (1 - self.momentum) * gradient
|
||||||
|
state["v"] = v
|
||||||
|
|
||||||
|
if self.nesterov:
|
||||||
|
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||||
|
else:
|
||||||
|
update = v
|
||||||
|
|
||||||
|
lr = self.learning_rate.astype(gradient.dtype)
|
||||||
|
|
||||||
|
if update.ndim >= 2:
|
||||||
|
original_shape = update.shape
|
||||||
|
reshape_needed = update.ndim > 2
|
||||||
|
|
||||||
|
if reshape_needed:
|
||||||
|
update = mx.reshape(update, (update.shape[0], -1))
|
||||||
|
|
||||||
|
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||||
|
|
||||||
|
if reshape_needed:
|
||||||
|
update = mx.reshape(update, original_shape)
|
||||||
|
|
||||||
|
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||||
|
|
||||||
|
return parameter - lr * update
|
||||||
|
|
||||||
|
|
||||||
def clip_grad_norm(grads, max_norm):
|
def clip_grad_norm(grads, max_norm):
|
||||||
"""Clips the global norm of the gradients.
|
"""Clips the global norm of the gradients.
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
auditwheel repair dist/* \
|
auditwheel repair dist/* \
|
||||||
--plat manylinux_2_39_x86_64 \
|
--plat manylinux_2_35_x86_64 \
|
||||||
--exclude libcublas* \
|
--exclude libcublas* \
|
||||||
--exclude libnvrtc* \
|
--exclude libnvrtc* \
|
||||||
-w wheel_tmp
|
-w wheel_tmp
|
||||||
|
|||||||
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
|
# Transposed c
|
||||||
|
a = mx.ones((10, 5)).T
|
||||||
|
b = mx.ones((5, 5))
|
||||||
|
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
||||||
|
expected = 1.5 * a + 0.5 * (b @ a)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
# Broadcast c
|
||||||
|
a = mx.ones((5, 5))
|
||||||
|
b = mx.ones((5, 5))
|
||||||
|
c = mx.ones((1, 5))
|
||||||
|
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
||||||
|
expected = 1.5 * c + 0.5 * (a @ b)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
def make_ref_addmm(alpha, beta):
|
def make_ref_addmm(alpha, beta):
|
||||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||||
|
|||||||
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(xp["x"].shape, x.shape)
|
self.assertEqual(xp["x"].shape, x.shape)
|
||||||
self.assertEqual(optimizer.state["step"], 2)
|
self.assertEqual(optimizer.state["step"], 2)
|
||||||
|
|
||||||
|
def test_muon(self):
|
||||||
|
params = {
|
||||||
|
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
|
||||||
|
"second": mx.zeros((3, 3)),
|
||||||
|
"conv": mx.zeros((16, 8, 3, 3)),
|
||||||
|
}
|
||||||
|
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||||
|
|
||||||
|
# Explicit init
|
||||||
|
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
|
||||||
|
optim.init(params)
|
||||||
|
self.assertTrue(
|
||||||
|
tree_equal(
|
||||||
|
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||||
|
params,
|
||||||
|
optim.state,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test update
|
||||||
|
updated_params = optim.apply_gradients(grads, params)
|
||||||
|
|
||||||
|
# Check that shapes are preserved
|
||||||
|
self.assertTrue(
|
||||||
|
tree_equal(
|
||||||
|
lambda p, u: p.shape == u.shape,
|
||||||
|
params,
|
||||||
|
updated_params,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that parameters actually changed
|
||||||
|
self.assertFalse(
|
||||||
|
tree_equal(
|
||||||
|
lambda p, u: mx.array_equal(p, u),
|
||||||
|
params,
|
||||||
|
updated_params,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with different configurations
|
||||||
|
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||||
|
optim_no_nesterov.apply_gradients(grads, params)
|
||||||
|
|
||||||
|
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||||
|
optim_no_momentum.apply_gradients(grads, params)
|
||||||
|
|
||||||
def test_compiled_optimizer(self):
|
def test_compiled_optimizer(self):
|
||||||
model = nn.Linear(10, 10)
|
model = nn.Linear(10, 10)
|
||||||
x = mx.random.uniform(shape=(2, 10))
|
x = mx.random.uniform(shape=(2, 10))
|
||||||
|
|||||||
@@ -39,6 +39,14 @@ target_sources(
|
|||||||
linalg_tests.cpp
|
linalg_tests.cpp
|
||||||
${METAL_TEST_SOURCES})
|
${METAL_TEST_SOURCES})
|
||||||
|
|
||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
# Find the CCCL headers in install dir.
|
||||||
|
target_compile_definitions(
|
||||||
|
mlx
|
||||||
|
PRIVATE
|
||||||
|
MLX_CCCL_DIR="${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/cccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_libraries(tests PRIVATE mlx doctest)
|
target_link_libraries(tests PRIVATE mlx doctest)
|
||||||
doctest_discover_tests(tests)
|
doctest_discover_tests(tests)
|
||||||
add_test(NAME tests COMMAND tests)
|
add_test(NAME tests COMMAND tests)
|
||||||
|
|||||||
Reference in New Issue
Block a user