mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Rename cu::Matmul to CublasGemm (#2488)
This commit is contained in:
@@ -7,10 +7,12 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
struct CublasPreference {
|
||||
CublasPreference(Device& device) {
|
||||
CublasPreference(cu::Device& device) {
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
@@ -33,7 +35,7 @@ struct CublasPreference {
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
};
|
||||
|
||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
|
||||
static CublasPreference pref(device);
|
||||
return pref.pref_;
|
||||
}
|
||||
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
|
||||
return CUDA_C_32F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
||||
return desc;
|
||||
}
|
||||
|
||||
Matmul::Matmul(
|
||||
Device& device,
|
||||
} // namespace
|
||||
|
||||
CublasGemm::CublasGemm(
|
||||
cu::Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -155,8 +159,8 @@ Matmul::Matmul(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
}
|
||||
|
||||
Matmul::Matmul(
|
||||
Device& device,
|
||||
CublasGemm::CublasGemm(
|
||||
cu::Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -171,7 +175,7 @@ Matmul::Matmul(
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride)
|
||||
: Matmul(
|
||||
: CublasGemm(
|
||||
device,
|
||||
dtype,
|
||||
a_transposed,
|
||||
@@ -190,7 +194,7 @@ Matmul::Matmul(
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
Matmul::~Matmul() {
|
||||
CublasGemm::~CublasGemm() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void Matmul::run_impl(
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
return;
|
||||
}
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
c_batch_strides,
|
||||
alpha,
|
||||
beta);
|
||||
return;
|
||||
}
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c.data<void>(),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
void CublasGemm::execute(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
const void* a,
|
||||
@@ -256,29 +326,4 @@ void Matmul::run_impl(
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
void Matmul::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::optional<array>& c /* = std::nullopt */,
|
||||
float alpha /* = 1 */,
|
||||
float beta /* = 0 */) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
if (c) {
|
||||
encoder.set_input_array(*c);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
|
||||
run_impl(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c ? c->data<void>() : nullptr,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user