Rename cu::Matmul to CublasGemm (#2488)

This commit is contained in:
Cheng
2025-08-13 09:37:40 +09:00
committed by GitHub
parent ac207ce7aa
commit dfb5022eab
6 changed files with 157 additions and 103 deletions

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core::cu {
class Matmul {
namespace mlx::core {
class CublasGemm {
public:
Matmul(
Device& device,
CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
int64_t a_batch_stride,
int64_t b_batch_stride);
Matmul(
Device& device,
CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -42,25 +42,39 @@ class Matmul {
int64_t b_batch_stride,
int64_t c_batch_stride);
~Matmul();
~CublasGemm();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
@@ -68,15 +82,14 @@ class Matmul {
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_impl(
void execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -97,4 +110,4 @@ class Matmul {
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core::cu
} // namespace mlx::core