Compare commits


5 Commits

Author SHA1 Message Date
Jagrit Digani
400f8457ea Experimenting with a gemm based on the cuda steel utils 2025-08-14 11:27:50 -07:00
Cheng
dfb5022eab Rename cu::Matmul to CublasGemm (#2488) 2025-08-13 09:37:40 +09:00
Daniel Yeh
ac207ce7aa make code blocks copyable (#2480) 2025-08-12 12:29:02 -07:00
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
Abe Leininger
fce53b61d6 Fix reduce sum/prod overflow (#2477) 2025-08-12 00:05:33 -07:00
Angelos Katharopoulos
8ae4a76308 Use CMake <4.1 to avoid the nvpl error (#2489) 2025-08-12 00:03:42 -07:00
17 changed files with 659 additions and 185 deletions

View File

@@ -1,4 +1,5 @@
 sphinx
 breathe
 sphinx-book-theme
+sphinx-copybutton
 mlx

View File

@@ -18,6 +18,7 @@ release = version
 # -- General configuration ---------------------------------------------------
 extensions = [
+    "sphinx_copybutton",
     "sphinx.ext.autodoc",
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",

View File

@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   switch (in.dtype()) {
     case bool_:
     case uint8:
+      reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
+      break;
+    case uint16:
+      reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
+      break;
+    case uint32:
+      reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
+      break;
+    case uint64:
+      reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
+      break;
     case int8:
       reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
       break;
     case int16:
-    case uint16:
       reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
       break;
     case int32:
-    case uint32:
       reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
       break;
     case int64:
-    case uint64:
       reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
       break;
     case float16:
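The fix above is easy to reproduce outside of MLX: summing unsigned 8-bit values can exceed the 8-bit range after only a handful of elements, so the reduction must accumulate in a type wider than the input. A minimal, self-contained C++ sketch (illustrative only, not MLX code):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<uint8_t> x(1000, 255);
  // Accumulating in the input type wraps around: 255000 mod 256 == 24.
  auto narrow = std::accumulate(x.begin(), x.end(), uint8_t{0});
  // Accumulating in a wider type preserves the exact value, which is what
  // the widened unsigned dispatch above guarantees.
  auto wide = std::accumulate(x.begin(), x.end(), uint32_t{0});
  std::cout << unsigned(narrow) << " vs " << wide << "\n"; // 24 vs 255000
  return 0;
}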

View File

@@ -24,6 +24,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -53,10 +54,10 @@ target_sources(
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
 else()
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
 endif()
 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

View File

@@ -7,10 +7,12 @@
 #include <fmt/format.h>
-namespace mlx::core::cu {
+namespace mlx::core {
+namespace {
 struct CublasPreference {
-  CublasPreference(Device& device) {
+  CublasPreference(cu::Device& device) {
     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
     // for Hopper+:
     // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -33,7 +35,7 @@ struct CublasPreference {
   cublasLtMatmulPreference_t pref_{nullptr};
 };
-cublasLtMatmulPreference_t cublas_preference(Device& device) {
+cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
   static CublasPreference pref(device);
   return pref.pref_;
 }
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
       return CUBLAS_COMPUTE_64F;
     default:
       throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
   }
 }
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
       return CUDA_C_32F;
     default:
       throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
   }
 }
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
   return desc;
 }
-Matmul::Matmul(
-    Device& device,
+} // namespace
+CublasGemm::CublasGemm(
+    cu::Device& device,
     Dtype dtype,
     bool a_transposed,
     uint64_t a_rows,
@@ -155,8 +159,8 @@ Matmul::Matmul(
       type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
 }
-Matmul::Matmul(
-    Device& device,
+CublasGemm::CublasGemm(
+    cu::Device& device,
     Dtype dtype,
     bool a_transposed,
     uint64_t a_rows,
@@ -171,7 +175,7 @@ Matmul::Matmul(
     int64_t a_batch_stride,
     int64_t b_batch_stride,
     int64_t c_batch_stride)
-    : Matmul(
+    : CublasGemm(
           device,
           dtype,
           a_transposed,
@@ -190,7 +194,7 @@ Matmul::Matmul(
       type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
 }
-Matmul::~Matmul() {
+CublasGemm::~CublasGemm() {
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
 }
-void Matmul::run_impl(
+void CublasGemm::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
+  int batch_count = out.size() / (M_ * N_);
+  if (batch_count / batch_shape.back() > 1) {
+    run_batched(
+        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+    return;
+  }
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+}
+void CublasGemm::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  int batch_count = out.size() / (M_ * N_);
+  if (batch_count / batch_shape.back() > 1) {
+    run_batched(
+        encoder,
+        out,
+        a,
+        b,
+        c,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        c_batch_strides,
+        alpha,
+        beta);
+    return;
+  }
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      c.data<void>(),
+      alpha,
+      beta);
+}
+void CublasGemm::execute(
     cu::CommandEncoder& encoder,
     void* out,
     const void* a,
@@ -256,29 +326,4 @@ void Matmul::run_impl(
         encoder.stream()));
 }
-void Matmul::run(
-    cu::CommandEncoder& encoder,
-    array& out,
-    const array& a,
-    const array& b,
-    const std::optional<array>& c /* = std::nullopt */,
-    float alpha /* = 1 */,
-    float beta /* = 0 */) {
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  if (c) {
-    encoder.set_input_array(*c);
-  }
-  encoder.set_output_array(out);
-  run_impl(
-      encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
-      c ? c->data<void>() : nullptr,
-      alpha,
-      beta);
-}
-} // namespace mlx::core::cu
+} // namespace mlx::core

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include <cublasLt.h> #include <cublasLt.h>
#include <optional>
namespace mlx::core::cu { namespace mlx::core {
class Matmul {
class CublasGemm {
public: public:
Matmul( CublasGemm(
Device& device, cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride); int64_t b_batch_stride);
Matmul( CublasGemm(
Device& device, cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@@ -42,25 +42,39 @@ class Matmul {
int64_t b_batch_stride, int64_t b_batch_stride,
int64_t c_batch_stride); int64_t c_batch_stride);
~Matmul(); ~CublasGemm();
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const std::optional<array>& c = std::nullopt, const Shape& batch_shape,
float alpha = 1, const Strides& a_batch_strides,
float beta = 0); const Strides& b_batch_strides);
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched( void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides); const Strides& b_batch_strides);
void run_batched( void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@@ -68,15 +82,14 @@ class Matmul {
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides, const Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides, const Strides& c_batch_strides,
float alpha, float alpha,
float beta); float beta);
private: void execute(
void run_impl(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
void* out, void* out,
const void* a, const void* a,
@@ -97,4 +110,4 @@ class Matmul {
cublasLtMatmulHeuristicResult_t heuristic_; cublasLtMatmulHeuristicResult_t heuristic_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core
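Taken together with cublas_gemm.cpp above, the API change moves the single-vs-batched decision out of every call site and into CublasGemm::run, with run_batched and execute demoted to private helpers. A toy model of that dispatch (standalone C++, all names hypothetical; the real test is the batch_count / batch_shape.back() > 1 check shown earlier):

#include <cstdio>
#include <vector>

struct Gemm {
  int M = 2, N = 2;
  // New-style run(): takes the batch geometry and dispatches internally.
  void run(int out_size, const std::vector<int>& batch_shape) {
    int batch_count = out_size / (M * N);
    if (batch_count / batch_shape.back() > 1) {
      run_batched(batch_count);
      return;
    }
    std::printf("single fused cublasLt call\n");
  }

 private:
  void run_batched(int batch_count) {
    std::printf("batched path, %d gemms\n", batch_count);
  }
};

int main() {
  Gemm g;
  g.run(4, {1});  // one 2x2 gemm -> fused path
  g.run(32, {2}); // 8 gemms, trailing batch dim of 2 -> batched path
  return 0;
}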

View File

@@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu { namespace mlx::core {
void Matmul::run_batched( void CublasGemm::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) { const Strides& b_batch_strides) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void Matmul::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context(); auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
run_impl( execute(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void Matmul::run_batched(
} }
} }
void Matmul::run_batched( void CublasGemm::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides, const Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides, const Strides& c_batch_strides,
float alpha, float alpha,
float beta) { float beta) {
encoder.set_input_array(a); encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void Matmul::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context(); auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
run_impl( execute(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void Matmul::run_batched(
} }
} }
} // namespace mlx::core::cu } // namespace mlx::core

View File

@@ -6,7 +6,9 @@
 #include <cooperative_groups.h>
-namespace mlx::core::cu {
+namespace mlx::core {
+namespace cu {
 namespace cg = cooperative_groups;
@@ -128,6 +130,10 @@ __global__ void set_addmm_device_pointers_g(
         out_start + item_size * index * batch_stride;
   }
 }
+} // namespace cu
+namespace {
 void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
   auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
@@ -139,14 +145,16 @@ void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
       desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
 }
-void Matmul::run_batched(
+} // namespace
+void CublasGemm::run_batched(
     cu::CommandEncoder& encoder,
     array& out,
     const array& a,
     const array& b,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides) {
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -213,7 +221,7 @@ void Matmul::run_batched(
   auto a_pointers = pointers.data<int8_t*>();
   auto b_pointers = a_pointers + batch_count;
   auto out_pointers = b_pointers + batch_count;
-  run_impl(
+  execute(
       encoder,
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
@@ -221,16 +229,16 @@ void Matmul::run_batched(
       nullptr);
 }
-void Matmul::run_batched(
+void CublasGemm::run_batched(
     cu::CommandEncoder& encoder,
     array& out,
     const array& a,
     const array& b,
     const array& c,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides,
-    const mlx::core::Strides& c_batch_strides,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
     float alpha,
     float beta) {
   int batch_count = out.size() / (M_ * N_);
@@ -306,7 +314,7 @@ void Matmul::run_batched(
   auto b_pointers = a_pointers + batch_count;
   auto c_pointers = b_pointers + batch_count;
   auto out_pointers = c_pointers + batch_count;
-  run_impl(
+  execute(
       encoder,
       reinterpret_cast<void*>(out_pointers),
       reinterpret_cast<void*>(a_pointers),
@@ -316,4 +324,4 @@ void Matmul::run_batched(
       beta);
 }
-} // namespace mlx::core::cu
+} // namespace mlx::core

View File

@@ -0,0 +1,301 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
#include <cooperative_groups.h>
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct GemmParams {
  int M;
  int N;
  int K;
  int lda;
  int ldb;
  int ldd;
  int NblockM;
  int NblockN;
  int NblockK;
};
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int SL,
    int Nstages>
__global__ void kernel_steel_gemm(
    const T* a,
    const T* b,
    T* d,
    __grid_constant__ const GemmParams params) {
  const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
  const int bN_idx = blockIdx.x >> SL;
  if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
    return;
  }
  const int d_row = bM_idx * BM;
  const int d_col = bN_idx * BN;
  const size_t d_row_long = size_t(d_row);
  const size_t d_col_long = size_t(d_col);
  a += transpose_a ? d_row_long : d_row_long * params.K;
  b += transpose_b ? d_col_long * params.K : d_col_long;
  d += d_row_long * params.ldd + d_col_long;
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);
  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();
  const int wm = warp_idx / WN;
  const int wn = warp_idx % WN;
  constexpr int SM = BM / WM;
  constexpr int SN = BN / WN;
  constexpr int SK = BK;
  constexpr int TK = SK / 16;
  constexpr int NUM_WARPS = WM * WN;
  // Allocate shared memory
  extern __shared__ char shmem[];
  SharedTile<T, BM, BK>(&as)[Nstages] =
      *(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
  SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
      &shmem[sizeof(T) * Nstages * BM * BK]);
  // Allocate registers for the MMA
  RegisterTile<float, SM, SN> C;
  RegisterTile<T, SM, 16> A[TK];
  RegisterTile<T, SN, 16> B[TK];
  // Zero the accumulators
  C.fill(0);
  // Start gmem -> smem copies
  int k_block_read = 0;
  MLX_UNROLL
  for (int bk = 0; bk < (Nstages - 1); bk++) {
    load_async<NUM_WARPS>(
        as[bk], as[bk].base_addr(), a + k_block_read, params.K);
    load_async<NUM_WARPS>(
        bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
    k_block_read += BK;
    cp_async_commit();
  }
  int smem_pipe_read = 0;
  int smem_pipe_write = Nstages - 1;
  // Wait until only one stage is still loading
  cp_async_wait<1>();
  block.sync();
  const int offset_m = wm * SM;
  const int offset_n = wn * SN;
  // Start smem -> register copy
  A[0].load(
      as[smem_pipe_read],
      as[smem_pipe_read].base_addr(),
      offset_m + lane_idx % 16,
      lane_idx / 16 * 8);
  B[0].load(
      bs[smem_pipe_read],
      bs[smem_pipe_read].base_addr(),
      offset_n + lane_idx % 16,
      lane_idx / 16 * 8);
  // Main loop
  for (int kb = 0; kb < params.NblockK; kb++) {
    // Prepare next registers
    {
      A[1].load(
          as[smem_pipe_read],
          as[smem_pipe_read].base_addr(),
          offset_m + lane_idx % 16,
          16 + lane_idx / 16 * 8);
      B[1].load(
          bs[smem_pipe_read],
          bs[smem_pipe_read].base_addr(),
          offset_n + lane_idx % 16,
          16 + lane_idx / 16 * 8);
    }
    // Prepare next smem
    if ((kb + Nstages - 1) < params.NblockK) {
      load_async<NUM_WARPS>(
          as[smem_pipe_write],
          as[smem_pipe_write].base_addr(),
          a + k_block_read,
          params.K);
      load_async<NUM_WARPS>(
          bs[smem_pipe_write],
          bs[smem_pipe_write].base_addr(),
          b + k_block_read,
          params.K);
    }
    k_block_read += BK;
    cp_async_commit();
    smem_pipe_write = smem_pipe_read;
    smem_pipe_read = smem_pipe_read + 1;
    smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;
    // Do current gemm
    mma_t(C, A[0], B[0]);
    // Do wait for next register
    cp_async_wait<1>();
    block.sync();
    // Prepare next register (smem_pipe_read has moved to the next)
    {
      A[0].load(
          as[smem_pipe_read],
          as[smem_pipe_read].base_addr(),
          offset_m + lane_idx % 16,
          lane_idx / 16 * 8);
      B[0].load(
          bs[smem_pipe_read],
          bs[smem_pipe_read].base_addr(),
          offset_n + lane_idx % 16,
          lane_idx / 16 * 8);
    }
    // Do current gemm
    mma_t(C, A[1], B[1]);
  }
  // Wait and clear
  cp_async_wait_all();
  block.sync();
  C.store_global(d, params.ldd, offset_m, offset_n);
}
} // namespace cu
void dispatch_steel_gemm(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& a,
    const array& b,
    array& d,
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldd,
    bool a_transposed,
    bool b_transposed) {
  using DataType = cuda_type_t<float16_t>;
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(d);
  constexpr int BM = 128;
  constexpr int BN = 128;
  constexpr int BK = 32;
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int SL = 0;
  constexpr int Nstages = 3;
  constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);
  const int NblockM = (M + BM - 1) / BM;
  const int NblockN = (N + BN - 1) / BN;
  const int NblockK = (K + BK - 1) / BK;
  cu::GemmParams params{
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int lda = */ lda,
      /* int ldb = */ ldb,
      /* int ldd = */ ldd,
      /* int NblockM = */ NblockM,
      /* int NblockN = */ NblockN,
      /* int NblockK = */ NblockK,
  };
  // Prepare launch grid params
  int tile = 1 << SL;
  int tm = (NblockM + tile - 1) / tile;
  int tn = NblockN * tile;
  dim3 grid_dim(tn, tm, 1);
  dim3 block_dim(32 * WM * WN, 1, 1);
  dispatch_bool(a_transposed, [&](auto ta_) {
    dispatch_bool(b_transposed, [&](auto tb_) {
      constexpr bool ta = ta_.value;
      constexpr bool tb = tb_.value;
      auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
      cudaFuncSetAttribute(
          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
      encoder.add_kernel_node(
          kernel,
          grid_dim,
          block_dim,
          smem_bytes,
          a.data<DataType>(),
          b.data<DataType>(),
          d.data<DataType>(),
          N,
          K);
      // auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
      // tb, SL, Nstages>;
      // cudaFuncSetAttribute(kernel,
      // cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
      // encoder.add_kernel_node(
      //     kernel,
      //     grid_dim,
      //     block_dim,
      //     smem_bytes,
      //     a.data<DataType>(),
      //     b.data<DataType>(),
      //     d.data<DataType>(),
      //     params);
    });
  });
}
} // namespace mlx::core
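Two pieces of the launch math above are worth replaying. The dynamic shared memory request is BK * (BM + BN) * Nstages * sizeof(fp16) = 32 * 256 * 3 * 2 = 49152 bytes (48 KiB), which is why the dispatch sets cudaFuncAttributeMaxDynamicSharedMemorySize before launching. And the SL parameter swizzles blockIdx so that consecutive blocks sweep a small window of M-tiles per N-tile. A host-side sketch (not part of the commit; SL = 1 is used here to make the swizzle visible, while the dispatch currently sets SL = 0, which reduces to the identity mapping):

#include <cstdio>

int main() {
  // Shared memory sizing, as in dispatch_steel_gemm.
  constexpr int BM = 128, BN = 128, BK = 32, Nstages = 3;
  constexpr int smem_bytes = BK * (BM + BN) * Nstages * 2; // sizeof(fp16) == 2
  std::printf("dynamic smem per block: %d bytes\n", smem_bytes); // 49152

  // Block-index swizzle from kernel_steel_gemm, for a 4x4 grid of tiles.
  constexpr int SL = 1; // hypothetical; the dispatch above uses SL = 0
  const int NblockM = 4, NblockN = 4;
  const int tile = 1 << SL;
  const int tm = (NblockM + tile - 1) / tile; // grid.y
  const int tn = NblockN * tile;              // grid.x
  for (int by = 0; by < tm; by++) {
    for (int bx = 0; bx < tn; bx++) {
      // Same decode as the first two lines of the kernel: pairs of M-tiles
      // are visited before the N-tile advances, improving B-tile reuse.
      const int bM = (by << SL) + (bx & ((1 << SL) - 1));
      const int bN = bx >> SL;
      std::printf("block (x=%d, y=%d) -> tile (M=%d, N=%d)\n", bx, by, bM, bN);
    }
  }
  return 0;
}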

View File

@@ -0,0 +1,27 @@
#pragma once
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
void dispatch_steel_gemm(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& a,
    const array& b,
    array& d,
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldd,
    bool a_transposed,
    bool b_transposed);
} // namespace mlx::core

View File

@@ -7,6 +7,8 @@
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <numeric> #include <numeric>
@@ -95,9 +97,27 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
b_transposed) {
return dispatch_steel_gemm(
/* const Stream& s = */ s,
/* cu::CommandEncoder& encoder = */ encoder,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& d = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ N,
/* bool a_transposed = */ a_transposed,
/* bool b_transposed = */ b_transposed);
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -111,14 +131,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(), batch_shape.back(),
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -186,7 +199,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -202,12 +215,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
gemm.run(
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
encoder, encoder,
out, out,
a, a,

View File

@@ -143,85 +143,87 @@ struct Tile16x16 {
   }
 };
-/**
- * A simple container of multiple Tile16x16.
- *
- * Provides utility functions for loading and manipulating collections of basic
- * tiles.
- */
-template <typename T, int ROWS_, int COLS_>
-struct RegisterTile {
-  static constexpr int ROWS = ROWS_;
-  static constexpr int COLS = COLS_;
-  static constexpr int TILES_X = COLS / 16;
-  static constexpr int TILES_Y = ROWS / 16;
-  Tile16x16<T> data[TILES_X * TILES_Y];
-  __device__ inline void fill(T v) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].fill(v);
-      }
-    }
-  }
-  template <typename Tile>
-  __device__ __forceinline__ void
-  load(Tile& tile, uint32_t base_address, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].load(
-            tile.loc(base_address, row + i * 16, col + j * 16));
-      }
-    }
-  }
-  template <typename Tile, typename F>
-  __device__ __forceinline__ void
-  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        f(data[i * TILES_X + j],
-          tile,
-          base_address,
-          row + i * 16,
-          col + j * 16);
-      }
-    }
-  }
-  template <typename U>
-  __device__ inline void store_global(U* x, int N, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].store_global(
-            x + (row + i * 16) * N + col + j * 16, N);
-      }
-    }
-  }
-  template <typename U>
-  __device__ inline void
-  store_global_safe(U* x, int N, int row, int col, int max_rows) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].store_global_safe(
-            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
-      }
-    }
-  }
-};
+// /**
+//  * A simple container of multiple Tile16x16.
+//  *
+//  * Provides utility functions for loading and manipulating collections of
+// basic
+//  * tiles.
+//  */
+// template <typename T, int ROWS_, int COLS_>
+// struct RegisterTile {
+//   static constexpr int ROWS = ROWS_;
+//   static constexpr int COLS = COLS_;
+//   static constexpr int TILES_X = COLS / 16;
+//   static constexpr int TILES_Y = ROWS / 16;
+//   Tile16x16<T> data[TILES_X * TILES_Y];
+//   __device__ inline void fill(T v) {
+//     MLX_UNROLL
+//     for (int i = 0; i < TILES_Y; i++) {
+//       MLX_UNROLL
+//       for (int j = 0; j < TILES_X; j++) {
+//         data[i * TILES_X + j].fill(v);
+//       }
+//     }
+//   }
+//   template <typename Tile>
+//   __device__ __forceinline__ void
+//   load(Tile& tile, uint32_t base_address, int row, int col) {
+//     MLX_UNROLL
+//     for (int i = 0; i < TILES_Y; i++) {
+//       MLX_UNROLL
+//       for (int j = 0; j < TILES_X; j++) {
+//         data[i * TILES_X + j].load(
+//             tile.loc(base_address, row + i * 16, col + j * 16));
+//       }
+//     }
+//   }
+//   template <typename Tile, typename F>
+//   __device__ __forceinline__ void
+//   load(Tile& tile, F f, uint32_t base_address, int row, int col) {
+//     MLX_UNROLL
+//     for (int i = 0; i < TILES_Y; i++) {
+//       MLX_UNROLL
+//       for (int j = 0; j < TILES_X; j++) {
+//         f(data[i * TILES_X + j],
+//           tile,
+//           base_address,
+//           row + i * 16,
+//           col + j * 16);
+//       }
+//     }
+//   }
+//   template <typename U>
+//   __device__ inline void store_global(U* x, int N, int row, int col) {
+//     MLX_UNROLL
+//     for (int i = 0; i < TILES_Y; i++) {
+//       MLX_UNROLL
+//       for (int j = 0; j < TILES_X; j++) {
+//         data[i * TILES_X + j].store_global(
+//             x + (row + i * 16) * N + col + j * 16, N);
+//       }
+//     }
+//   }
+//   template <typename U>
+//   __device__ inline void
+//   store_global_safe(U* x, int N, int row, int col, int max_rows) {
+//     MLX_UNROLL
+//     for (int i = 0; i < TILES_Y; i++) {
+//       MLX_UNROLL
+//       for (int j = 0; j < TILES_X; j++) {
+//         data[i * TILES_X + j].store_global_safe(
+//             x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
+// 16);
+//       }
+//     }
+//   }
+// };
 /**
  * A simple container of multiple Tile16x16.

View File

@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
 instantiate_and_or(or, Or)
 #define instantiate_sum_prod(name, op) \
+  instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
+  instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
+  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
+  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
   instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
   instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
   instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \

View File

@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
     const std::string& op_name) {
   if (op_name == "sum" || op_name == "prod") {
     if (issubdtype(in.dtype(), integer)) {
-      switch (in.dtype().size()) {
-        case 1:
+      switch (in.dtype()) {
+        case uint8:
+          return {uint8, uint32};
+        case uint16:
+          return {uint16, uint32};
+        case uint32:
+          return {uint32, uint32};
+        case uint64:
+          return {uint64, uint64};
+        case int8:
           return {int8, int32};
-        case 2:
+        case int16:
           return {int16, int32};
-        case 4:
+        case int32:
           return {int32, int32};
-        case 8:
+        case int64:
           return {int64, int64};
+        default:
+          throw std::runtime_error("Unsupported integer type");
       }
     }
     if (in.dtype() == bool_) {

View File

@@ -2,6 +2,6 @@
 requires = [
     "setuptools>=80",
     "nanobind==2.4.0",
-    "cmake>=3.25",
+    "cmake>=3.25,<4.1",
 ]
 build-backend = "setuptools.build_meta"

View File

@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
     CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
   }
+  // sum and prod overflow
+  {
+    auto a = full({256, 2, 2}, 1u, uint8);
+    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
+    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
+    a = full({65535, 2, 2}, 1u, uint16);
+    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
+    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
+  }
+}
+TEST_CASE("test gpu reduce with axes") {
   // reducing only some axes and irregular layouts
   {
     array a(1.0f);

View File

@@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
     CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
   }
+  // Test unsigned sum
+  {
+    const int num_elems = 1000;
+    auto x = astype(full({num_elems}, 255), uint8);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
+    x = astype(full({num_elems}, 65535), uint16);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
+    x = full({3, 3, 3}, 10000, uint32);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
+    x = full({3, 3, 3}, 10000, uint64);
+    CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
+  }
   // Test prod
   {
     auto x = array({});
@@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
     CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
   }
+  // Test unsigned prod
+  {
+    auto x = array({255, 255}, {2}, uint8);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
+    x = array({65535, 2}, {2}, uint16);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
+    x = array({100000, 2}, {2}, uint32);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
+    x = array({100000, 2}, {2}, uint64);
+    CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
+  }
   // Test all
   {
     auto x = array({});