Rename cu::Matmul to CublasGemm (#2488)

This commit is contained in:
Cheng 2025-08-13 09:37:40 +09:00 committed by GitHub
parent ac207ce7aa
commit dfb5022eab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 157 additions and 103 deletions

View File

@ -53,10 +53,10 @@ target_sources(
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources( target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu) mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else() else()
target_sources( target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp) mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif() endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

View File

@ -7,10 +7,12 @@
#include <fmt/format.h> #include <fmt/format.h>
namespace mlx::core::cu { namespace mlx::core {
namespace {
struct CublasPreference { struct CublasPreference {
CublasPreference(Device& device) { CublasPreference(cu::Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+: // for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@ -33,7 +35,7 @@ struct CublasPreference {
cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr};
}; };
cublasLtMatmulPreference_t cublas_preference(Device& device) { cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
static CublasPreference pref(device); static CublasPreference pref(device);
return pref.pref_; return pref.pref_;
} }
@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return CUBLAS_COMPUTE_64F; return CUBLAS_COMPUTE_64F;
default: default:
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
} }
} }
@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
return CUDA_C_32F; return CUDA_C_32F;
default: default:
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
} }
} }
@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
return desc; return desc;
} }
Matmul::Matmul( } // namespace
Device& device,
CublasGemm::CublasGemm(
cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@ -155,8 +159,8 @@ Matmul::Matmul(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
} }
Matmul::Matmul( CublasGemm::CublasGemm(
Device& device, cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@ -171,7 +175,7 @@ Matmul::Matmul(
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride, int64_t b_batch_stride,
int64_t c_batch_stride) int64_t c_batch_stride)
: Matmul( : CublasGemm(
device, device,
dtype, dtype,
a_transposed, a_transposed,
@ -190,7 +194,7 @@ Matmul::Matmul(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
} }
Matmul::~Matmul() { CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@ -198,7 +202,73 @@ Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
} }
void Matmul::run_impl( void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
void* out, void* out,
const void* a, const void* a,
@ -256,29 +326,4 @@ void Matmul::run_impl(
encoder.stream())); encoder.stream()));
} }
void Matmul::run( } // namespace mlx::core
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include <cublasLt.h> #include <cublasLt.h>
#include <optional>
namespace mlx::core::cu { namespace mlx::core {
class Matmul {
class CublasGemm {
public: public:
Matmul( CublasGemm(
Device& device, cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@ -25,8 +25,8 @@ class Matmul {
int64_t a_batch_stride, int64_t a_batch_stride,
int64_t b_batch_stride); int64_t b_batch_stride);
Matmul( CublasGemm(
Device& device, cu::Device& device,
Dtype dtype, Dtype dtype,
bool a_transposed, bool a_transposed,
uint64_t a_rows, uint64_t a_rows,
@ -42,25 +42,39 @@ class Matmul {
int64_t b_batch_stride, int64_t b_batch_stride,
int64_t c_batch_stride); int64_t c_batch_stride);
~Matmul(); ~CublasGemm();
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const std::optional<array>& c = std::nullopt, const Shape& batch_shape,
float alpha = 1, const Strides& a_batch_strides,
float beta = 0); const Strides& b_batch_strides);
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched( void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides); const Strides& b_batch_strides);
void run_batched( void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@ -68,15 +82,14 @@ class Matmul {
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides, const Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides, const Strides& c_batch_strides,
float alpha, float alpha,
float beta); float beta);
private: void execute(
void run_impl(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
void* out, void* out,
const void* a, const void* a,
@ -97,4 +110,4 @@ class Matmul {
cublasLtMatmulHeuristicResult_t heuristic_; cublasLtMatmulHeuristicResult_t heuristic_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core

View File

@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu { namespace mlx::core {
void Matmul::run_batched( void CublasGemm::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) { const Strides& b_batch_strides) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
@ -22,7 +22,7 @@ void Matmul::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context(); auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
run_impl( execute(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
@ -33,16 +33,16 @@ void Matmul::run_batched(
} }
} }
void Matmul::run_batched( void CublasGemm::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides, const Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides, const Strides& c_batch_strides,
float alpha, float alpha,
float beta) { float beta) {
encoder.set_input_array(a); encoder.set_input_array(a);
@ -56,7 +56,7 @@ void Matmul::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context(); auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { for (size_t i = 0; i < nbatch; ++i) {
run_impl( execute(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
@ -70,4 +70,4 @@ void Matmul::run_batched(
} }
} }
} // namespace mlx::core::cu } // namespace mlx::core

View File

@ -6,7 +6,9 @@
#include <cooperative_groups.h> #include <cooperative_groups.h>
namespace mlx::core::cu { namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
@ -128,6 +130,10 @@ __global__ void set_addmm_device_pointers_g(
out_start + item_size * index * batch_stride; out_start + item_size * index * batch_stride;
} }
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) { void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY; auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
@ -139,14 +145,16 @@ void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t))); desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
} }
void Matmul::run_batched( } // namespace
void CublasGemm::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) { const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count); set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count); set_pointer_mode(b_desc_, batch_count);
@ -213,7 +221,7 @@ void Matmul::run_batched(
auto a_pointers = pointers.data<int8_t*>(); auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count; auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count; auto out_pointers = b_pointers + batch_count;
run_impl( execute(
encoder, encoder,
reinterpret_cast<void*>(out_pointers), reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers), reinterpret_cast<void*>(a_pointers),
@ -221,16 +229,16 @@ void Matmul::run_batched(
nullptr); nullptr);
} }
void Matmul::run_batched( void CublasGemm::run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
const array& a, const array& a,
const array& b, const array& b,
const array& c, const array& c,
const mlx::core::Shape& batch_shape, const Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides, const Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides, const Strides& c_batch_strides,
float alpha, float alpha,
float beta) { float beta) {
int batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
@ -306,7 +314,7 @@ void Matmul::run_batched(
auto b_pointers = a_pointers + batch_count; auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count; auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count; auto out_pointers = c_pointers + batch_count;
run_impl( execute(
encoder, encoder,
reinterpret_cast<void*>(out_pointers), reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers), reinterpret_cast<void*>(a_pointers),
@ -316,4 +324,4 @@ void Matmul::run_batched(
beta); beta);
} }
} // namespace mlx::core::cu } // namespace mlx::core

View File

@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@ -111,14 +111,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(), batch_shape.back(),
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -186,7 +179,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@ -202,12 +195,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
gemm.run(
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
encoder, encoder,
out, out,
a, a,