mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 11:38:06 +08:00
Rename cu::Matmul to CublasGemm (#2488)
This commit is contained in:
parent
ac207ce7aa
commit
dfb5022eab
@ -53,10 +53,10 @@ target_sources(
|
|||||||
|
|
||||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
|
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||||
|
@ -7,10 +7,12 @@
|
|||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
struct CublasPreference {
|
struct CublasPreference {
|
||||||
CublasPreference(Device& device) {
|
CublasPreference(cu::Device& device) {
|
||||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
// for Hopper+:
|
// for Hopper+:
|
||||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||||
@ -33,7 +35,7 @@ struct CublasPreference {
|
|||||||
cublasLtMatmulPreference_t pref_{nullptr};
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
|
||||||
static CublasPreference pref(device);
|
static CublasPreference pref(device);
|
||||||
return pref.pref_;
|
return pref.pref_;
|
||||||
}
|
}
|
||||||
@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
|||||||
return CUBLAS_COMPUTE_64F;
|
return CUBLAS_COMPUTE_64F;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
|
|||||||
return CUDA_C_32F;
|
return CUDA_C_32F;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
|||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
Matmul::Matmul(
|
} // namespace
|
||||||
Device& device,
|
|
||||||
|
CublasGemm::CublasGemm(
|
||||||
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@ -155,8 +159,8 @@ Matmul::Matmul(
|
|||||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
}
|
}
|
||||||
|
|
||||||
Matmul::Matmul(
|
CublasGemm::CublasGemm(
|
||||||
Device& device,
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@ -171,7 +175,7 @@ Matmul::Matmul(
|
|||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride,
|
int64_t b_batch_stride,
|
||||||
int64_t c_batch_stride)
|
int64_t c_batch_stride)
|
||||||
: Matmul(
|
: CublasGemm(
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
a_transposed,
|
a_transposed,
|
||||||
@ -190,7 +194,7 @@ Matmul::Matmul(
|
|||||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
Matmul::~Matmul() {
|
CublasGemm::~CublasGemm() {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||||
@ -198,7 +202,73 @@ Matmul::~Matmul() {
|
|||||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_impl(
|
void CublasGemm::run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides) {
|
||||||
|
int batch_count = out.size() / (M_ * N_);
|
||||||
|
if (batch_count / batch_shape.back() > 1) {
|
||||||
|
run_batched(
|
||||||
|
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CublasGemm::run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const array& c,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides,
|
||||||
|
const Strides& c_batch_strides,
|
||||||
|
float alpha,
|
||||||
|
float beta) {
|
||||||
|
int batch_count = out.size() / (M_ * N_);
|
||||||
|
if (batch_count / batch_shape.back() > 1) {
|
||||||
|
run_batched(
|
||||||
|
encoder,
|
||||||
|
out,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c,
|
||||||
|
batch_shape,
|
||||||
|
a_batch_strides,
|
||||||
|
b_batch_strides,
|
||||||
|
c_batch_strides,
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
execute(
|
||||||
|
encoder,
|
||||||
|
out.data<void>(),
|
||||||
|
a.data<void>(),
|
||||||
|
b.data<void>(),
|
||||||
|
c.data<void>(),
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CublasGemm::execute(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
void* out,
|
void* out,
|
||||||
const void* a,
|
const void* a,
|
||||||
@ -256,29 +326,4 @@ void Matmul::run_impl(
|
|||||||
encoder.stream()));
|
encoder.stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run(
|
} // namespace mlx::core
|
||||||
cu::CommandEncoder& encoder,
|
|
||||||
array& out,
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
const std::optional<array>& c /* = std::nullopt */,
|
|
||||||
float alpha /* = 1 */,
|
|
||||||
float beta /* = 0 */) {
|
|
||||||
encoder.set_input_array(a);
|
|
||||||
encoder.set_input_array(b);
|
|
||||||
if (c) {
|
|
||||||
encoder.set_input_array(*c);
|
|
||||||
}
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
run_impl(
|
|
||||||
encoder,
|
|
||||||
out.data<void>(),
|
|
||||||
a.data<void>(),
|
|
||||||
b.data<void>(),
|
|
||||||
c ? c->data<void>() : nullptr,
|
|
||||||
alpha,
|
|
||||||
beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
|
||||||
|
@ -5,13 +5,13 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
#include <cublasLt.h>
|
#include <cublasLt.h>
|
||||||
#include <optional>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
class Matmul {
|
|
||||||
|
class CublasGemm {
|
||||||
public:
|
public:
|
||||||
Matmul(
|
CublasGemm(
|
||||||
Device& device,
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@ -25,8 +25,8 @@ class Matmul {
|
|||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
int64_t b_batch_stride);
|
int64_t b_batch_stride);
|
||||||
|
|
||||||
Matmul(
|
CublasGemm(
|
||||||
Device& device,
|
cu::Device& device,
|
||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
uint64_t a_rows,
|
uint64_t a_rows,
|
||||||
@ -42,25 +42,39 @@ class Matmul {
|
|||||||
int64_t b_batch_stride,
|
int64_t b_batch_stride,
|
||||||
int64_t c_batch_stride);
|
int64_t c_batch_stride);
|
||||||
|
|
||||||
~Matmul();
|
~CublasGemm();
|
||||||
|
|
||||||
void run(
|
void run(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const std::optional<array>& c = std::nullopt,
|
const Shape& batch_shape,
|
||||||
float alpha = 1,
|
const Strides& a_batch_strides,
|
||||||
float beta = 0);
|
const Strides& b_batch_strides);
|
||||||
|
|
||||||
|
void run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const array& c,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides,
|
||||||
|
const Strides& c_batch_strides,
|
||||||
|
float alpha,
|
||||||
|
float beta);
|
||||||
|
|
||||||
|
private:
|
||||||
void run_batched(
|
void run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides);
|
const Strides& b_batch_strides);
|
||||||
|
|
||||||
void run_batched(
|
void run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
@ -68,15 +82,14 @@ class Matmul {
|
|||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides,
|
const Strides& b_batch_strides,
|
||||||
const mlx::core::Strides& c_batch_strides,
|
const Strides& c_batch_strides,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta);
|
float beta);
|
||||||
|
|
||||||
private:
|
void execute(
|
||||||
void run_impl(
|
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
void* out,
|
void* out,
|
||||||
const void* a,
|
const void* a,
|
||||||
@ -97,4 +110,4 @@ class Matmul {
|
|||||||
cublasLtMatmulHeuristicResult_t heuristic_;
|
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core
|
||||||
|
@ -4,16 +4,16 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
|
|
||||||
void Matmul::run_batched(
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides) {
|
const Strides& b_batch_strides) {
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
encoder.set_input_array(b);
|
encoder.set_input_array(b);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
@ -22,7 +22,7 @@ void Matmul::run_batched(
|
|||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
auto concurrent = encoder.concurrent_context();
|
auto concurrent = encoder.concurrent_context();
|
||||||
for (size_t i = 0; i < nbatch; ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
@ -33,16 +33,16 @@ void Matmul::run_batched(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_batched(
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides,
|
const Strides& b_batch_strides,
|
||||||
const mlx::core::Strides& c_batch_strides,
|
const Strides& c_batch_strides,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta) {
|
||||||
encoder.set_input_array(a);
|
encoder.set_input_array(a);
|
||||||
@ -56,7 +56,7 @@ void Matmul::run_batched(
|
|||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
auto concurrent = encoder.concurrent_context();
|
auto concurrent = encoder.concurrent_context();
|
||||||
for (size_t i = 0; i < nbatch; ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
@ -70,4 +70,4 @@ void Matmul::run_batched(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core
|
@ -6,7 +6,9 @@
|
|||||||
|
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
@ -128,6 +130,10 @@ __global__ void set_addmm_device_pointers_g(
|
|||||||
out_start + item_size * index * batch_stride;
|
out_start + item_size * index * batch_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
@ -139,14 +145,16 @@ void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
|||||||
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_batched(
|
} // namespace
|
||||||
|
|
||||||
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides) {
|
const Strides& b_batch_strides) {
|
||||||
int batch_count = out.size() / (M_ * N_);
|
int batch_count = out.size() / (M_ * N_);
|
||||||
set_pointer_mode(a_desc_, batch_count);
|
set_pointer_mode(a_desc_, batch_count);
|
||||||
set_pointer_mode(b_desc_, batch_count);
|
set_pointer_mode(b_desc_, batch_count);
|
||||||
@ -213,7 +221,7 @@ void Matmul::run_batched(
|
|||||||
auto a_pointers = pointers.data<int8_t*>();
|
auto a_pointers = pointers.data<int8_t*>();
|
||||||
auto b_pointers = a_pointers + batch_count;
|
auto b_pointers = a_pointers + batch_count;
|
||||||
auto out_pointers = b_pointers + batch_count;
|
auto out_pointers = b_pointers + batch_count;
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
reinterpret_cast<void*>(out_pointers),
|
reinterpret_cast<void*>(out_pointers),
|
||||||
reinterpret_cast<void*>(a_pointers),
|
reinterpret_cast<void*>(a_pointers),
|
||||||
@ -221,16 +229,16 @@ void Matmul::run_batched(
|
|||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Matmul::run_batched(
|
void CublasGemm::run_batched(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
array& out,
|
array& out,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
const mlx::core::Shape& batch_shape,
|
const Shape& batch_shape,
|
||||||
const mlx::core::Strides& a_batch_strides,
|
const Strides& a_batch_strides,
|
||||||
const mlx::core::Strides& b_batch_strides,
|
const Strides& b_batch_strides,
|
||||||
const mlx::core::Strides& c_batch_strides,
|
const Strides& c_batch_strides,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta) {
|
||||||
int batch_count = out.size() / (M_ * N_);
|
int batch_count = out.size() / (M_ * N_);
|
||||||
@ -306,7 +314,7 @@ void Matmul::run_batched(
|
|||||||
auto b_pointers = a_pointers + batch_count;
|
auto b_pointers = a_pointers + batch_count;
|
||||||
auto c_pointers = b_pointers + batch_count;
|
auto c_pointers = b_pointers + batch_count;
|
||||||
auto out_pointers = c_pointers + batch_count;
|
auto out_pointers = c_pointers + batch_count;
|
||||||
run_impl(
|
execute(
|
||||||
encoder,
|
encoder,
|
||||||
reinterpret_cast<void*>(out_pointers),
|
reinterpret_cast<void*>(out_pointers),
|
||||||
reinterpret_cast<void*>(a_pointers),
|
reinterpret_cast<void*>(a_pointers),
|
||||||
@ -316,4 +324,4 @@ void Matmul::run_batched(
|
|||||||
beta);
|
beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core
|
@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
cu::Matmul matmul(
|
CublasGemm gemm(
|
||||||
cu::device(s.device),
|
cu::device(s.device),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
@ -111,14 +111,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
batch_shape.back(),
|
batch_shape.back(),
|
||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||||
if ((batch_count / batch_shape.back()) == 1) {
|
|
||||||
matmul.run(encoder, out, a, b);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
matmul.run_batched(
|
|
||||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -186,7 +179,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
|
|
||||||
cu::Matmul matmul(
|
CublasGemm gemm(
|
||||||
cu::device(s.device),
|
cu::device(s.device),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
a_transposed,
|
a_transposed,
|
||||||
@ -202,12 +195,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
gemm.run(
|
||||||
if ((batch_count / batch_shape.back()) == 1) {
|
|
||||||
matmul.run(encoder, out, a, b, c, alpha_, beta_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
matmul.run_batched(
|
|
||||||
encoder,
|
encoder,
|
||||||
out,
|
out,
|
||||||
a,
|
a,
|
||||||
|
Loading…
Reference in New Issue
Block a user