mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
CUDA backend: matmul (#2241)
This commit is contained in:
parent
c6a20b427a
commit
24f89173d1
78
mlx/backend/common/matmul.h
Normal file
78
mlx/backend/common/matmul.h
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||||
|
const array& a,
|
||||||
|
const array& b) {
|
||||||
|
// Get and check the shape for the batched dims
|
||||||
|
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||||
|
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||||
|
if (A_bshape != B_bshape) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||||
|
<< a.shape() << ", B " << b.shape() << ".";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||||
|
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||||
|
|
||||||
|
auto [batch_shape, batch_strides] =
|
||||||
|
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||||
|
|
||||||
|
auto a_batch_strides = batch_strides[0];
|
||||||
|
auto b_batch_strides = batch_strides[1];
|
||||||
|
|
||||||
|
if (batch_shape.empty()) {
|
||||||
|
batch_shape.push_back(1);
|
||||||
|
a_batch_strides.push_back(0);
|
||||||
|
b_batch_strides.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||||
|
collapse_batches(const array& a, const array& b, const array& c) {
|
||||||
|
// Get and check the shape for the batched dims
|
||||||
|
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||||
|
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||||
|
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||||
|
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
||||||
|
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||||
|
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||||
|
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||||
|
|
||||||
|
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||||
|
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||||
|
|
||||||
|
auto A_batch_stride = batch_strides[0];
|
||||||
|
auto B_batch_stride = batch_strides[1];
|
||||||
|
auto C_batch_stride = batch_strides[2];
|
||||||
|
|
||||||
|
if (batch_shape.empty()) {
|
||||||
|
batch_shape.push_back(1);
|
||||||
|
A_batch_stride.push_back(0);
|
||||||
|
B_batch_stride.push_back(0);
|
||||||
|
C_batch_stride.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(
|
||||||
|
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -12,6 +12,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
@ -53,6 +54,9 @@ target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
|||||||
find_package(CUDAToolkit REQUIRED)
|
find_package(CUDAToolkit REQUIRED)
|
||||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
# Use cublasLt.
|
||||||
|
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||||
|
|
||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
@ -34,14 +34,26 @@ CommandEncoder& DeviceStream::get_encoder() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Device::Device(int device) : device_(device) {
|
Device::Device(int device) : device_(device) {
|
||||||
|
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||||
|
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
|
||||||
|
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||||
|
&compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
|
||||||
// Validate the requirements of device.
|
// Validate the requirements of device.
|
||||||
int attr = 0;
|
int attr = 0;
|
||||||
cudaDeviceGetAttribute(&attr, cudaDevAttrConcurrentManagedAccess, device_);
|
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||||
|
&attr, cudaDevAttrConcurrentManagedAccess, device_));
|
||||||
if (attr != 1) {
|
if (attr != 1) {
|
||||||
throw std::runtime_error(fmt::format(
|
throw std::runtime_error(fmt::format(
|
||||||
"Device {} does not support synchronization in managed memory.",
|
"Device {} does not support synchronization in managed memory.",
|
||||||
device_));
|
device_));
|
||||||
}
|
}
|
||||||
|
// The cublasLt handle is used by matmul.
|
||||||
|
make_current();
|
||||||
|
cublasLtCreate(<_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Device::~Device() {
|
||||||
|
cublasLtDestroy(lt_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::make_current() {
|
void Device::make_current() {
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
#include <cublasLt.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@ -46,6 +47,7 @@ class DeviceStream {
|
|||||||
class Device {
|
class Device {
|
||||||
public:
|
public:
|
||||||
explicit Device(int device);
|
explicit Device(int device);
|
||||||
|
~Device();
|
||||||
|
|
||||||
Device(const Device&) = delete;
|
Device(const Device&) = delete;
|
||||||
Device& operator=(const Device&) = delete;
|
Device& operator=(const Device&) = delete;
|
||||||
@ -58,9 +60,21 @@ class Device {
|
|||||||
int cuda_device() const {
|
int cuda_device() const {
|
||||||
return device_;
|
return device_;
|
||||||
}
|
}
|
||||||
|
int compute_capability_major() const {
|
||||||
|
return compute_capability_major_;
|
||||||
|
}
|
||||||
|
int compute_capability_minor() const {
|
||||||
|
return compute_capability_minor_;
|
||||||
|
}
|
||||||
|
cublasLtHandle_t lt_handle() const {
|
||||||
|
return lt_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int device_;
|
int device_;
|
||||||
|
int compute_capability_major_;
|
||||||
|
int compute_capability_minor_;
|
||||||
|
cublasLtHandle_t lt_;
|
||||||
std::unordered_map<int, DeviceStream> streams_;
|
std::unordered_map<int, DeviceStream> streams_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
474
mlx/backend/cuda/matmul.cpp
Normal file
474
mlx/backend/cuda/matmul.cpp
Normal file
@ -0,0 +1,474 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/matmul.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cublasLt.h>
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||||
|
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||||
|
// TODO: Use cublasGetStatusString when it is widely available.
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class MatMul {
|
||||||
|
public:
|
||||||
|
MatMul(
|
||||||
|
Device& device,
|
||||||
|
Dtype dtype,
|
||||||
|
bool a_transposed,
|
||||||
|
uint64_t a_rows,
|
||||||
|
uint64_t a_cols,
|
||||||
|
int64_t lda,
|
||||||
|
bool b_transposed,
|
||||||
|
uint64_t b_rows,
|
||||||
|
uint64_t b_cols,
|
||||||
|
int64_t ldb,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t a_batch_stride,
|
||||||
|
int64_t b_batch_stride) {
|
||||||
|
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||||
|
|
||||||
|
auto type = dtype_to_cuda_type(dtype);
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||||
|
&matmul_desc_, dtype_to_compute_type(dtype), type));
|
||||||
|
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
|
matmul_desc_,
|
||||||
|
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||||
|
&pointer_mode,
|
||||||
|
sizeof(int32_t)));
|
||||||
|
cublasOperation_t op = CUBLAS_OP_N;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
|
matmul_desc_,
|
||||||
|
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||||
|
&op,
|
||||||
|
sizeof(cublasOperation_t)));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||||
|
matmul_desc_,
|
||||||
|
CUBLASLT_MATMUL_DESC_TRANSB,
|
||||||
|
&op,
|
||||||
|
sizeof(cublasOperation_t)));
|
||||||
|
|
||||||
|
a_desc_ = create_matrix_layout(
|
||||||
|
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||||
|
b_desc_ = create_matrix_layout(
|
||||||
|
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||||
|
out_desc_ = create_matrix_layout(
|
||||||
|
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||||
|
|
||||||
|
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||||
|
// for Hopper+:
|
||||||
|
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||||
|
uint64_t MiB = 1024 * 1024;
|
||||||
|
uint64_t workspace_size =
|
||||||
|
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
|
||||||
|
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
|
||||||
|
pref_,
|
||||||
|
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||||
|
&workspace_size,
|
||||||
|
sizeof(uint64_t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
MatMul(
|
||||||
|
Device& device,
|
||||||
|
Dtype dtype,
|
||||||
|
bool a_transposed,
|
||||||
|
uint64_t a_rows,
|
||||||
|
uint64_t a_cols,
|
||||||
|
int64_t lda,
|
||||||
|
bool b_transposed,
|
||||||
|
uint64_t b_rows,
|
||||||
|
uint64_t b_cols,
|
||||||
|
int64_t ldb,
|
||||||
|
bool c_transposed,
|
||||||
|
int64_t ldc,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t a_batch_stride,
|
||||||
|
int64_t b_batch_stride,
|
||||||
|
int64_t c_batch_stride)
|
||||||
|
: MatMul(
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
a_transposed,
|
||||||
|
a_rows,
|
||||||
|
a_cols,
|
||||||
|
lda,
|
||||||
|
b_transposed,
|
||||||
|
b_rows,
|
||||||
|
b_cols,
|
||||||
|
ldb,
|
||||||
|
batch_count,
|
||||||
|
a_batch_stride,
|
||||||
|
b_batch_stride) {
|
||||||
|
auto type = dtype_to_cuda_type(dtype);
|
||||||
|
c_desc_ = create_matrix_layout(
|
||||||
|
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
~MatMul() {
|
||||||
|
cublasLtMatrixLayoutDestroy(a_desc_);
|
||||||
|
cublasLtMatrixLayoutDestroy(b_desc_);
|
||||||
|
cublasLtMatrixLayoutDestroy(c_desc_);
|
||||||
|
cublasLtMatrixLayoutDestroy(out_desc_);
|
||||||
|
cublasLtMatmulDescDestroy(matmul_desc_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
void* out,
|
||||||
|
void* a,
|
||||||
|
void* b,
|
||||||
|
void* c = nullptr,
|
||||||
|
float alpha = 1,
|
||||||
|
float beta = 0) {
|
||||||
|
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
|
||||||
|
int ret = 0;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
|
||||||
|
encoder.device().lt_handle(),
|
||||||
|
matmul_desc_,
|
||||||
|
a_desc_,
|
||||||
|
b_desc_,
|
||||||
|
out_desc_,
|
||||||
|
out_desc_,
|
||||||
|
pref_,
|
||||||
|
1,
|
||||||
|
&heuristic_,
|
||||||
|
&ret));
|
||||||
|
if (ret == 0) {
|
||||||
|
throw std::runtime_error("Can not find algorithm for matmul.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array workspace(
|
||||||
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
|
int8);
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
|
||||||
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||||
|
encoder.device().lt_handle(),
|
||||||
|
matmul_desc_,
|
||||||
|
&alpha,
|
||||||
|
a,
|
||||||
|
a_desc_,
|
||||||
|
b,
|
||||||
|
b_desc_,
|
||||||
|
&beta,
|
||||||
|
c ? c : out,
|
||||||
|
c ? c_desc_ : out_desc_,
|
||||||
|
out,
|
||||||
|
out_desc_,
|
||||||
|
&heuristic_.algo,
|
||||||
|
workspace.data<void>(),
|
||||||
|
workspace.nbytes(),
|
||||||
|
stream));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case uint8:
|
||||||
|
case uint16:
|
||||||
|
case int8:
|
||||||
|
case int16:
|
||||||
|
case int32:
|
||||||
|
return CUBLAS_COMPUTE_32I;
|
||||||
|
case float16:
|
||||||
|
case bfloat16:
|
||||||
|
return CUBLAS_COMPUTE_16F;
|
||||||
|
case float32:
|
||||||
|
return CUBLAS_COMPUTE_32F;
|
||||||
|
case float64:
|
||||||
|
case complex64:
|
||||||
|
return CUBLAS_COMPUTE_64F;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
|
||||||
|
switch (dtype) {
|
||||||
|
case uint8:
|
||||||
|
return CUDA_R_8U;
|
||||||
|
case uint16:
|
||||||
|
return CUDA_R_16U;
|
||||||
|
case int8:
|
||||||
|
return CUDA_R_8I;
|
||||||
|
case int16:
|
||||||
|
return CUDA_R_16I;
|
||||||
|
case int32:
|
||||||
|
return CUDA_R_32I;
|
||||||
|
case float16:
|
||||||
|
return CUDA_R_16F;
|
||||||
|
case bfloat16:
|
||||||
|
return CUDA_R_16BF;
|
||||||
|
case float32:
|
||||||
|
return CUDA_R_32F;
|
||||||
|
case float64:
|
||||||
|
return CUDA_R_64F;
|
||||||
|
case complex64:
|
||||||
|
return CUDA_C_32F;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cublasLtMatrixLayout_t create_matrix_layout(
|
||||||
|
cudaDataType_t type,
|
||||||
|
uint64_t rows,
|
||||||
|
uint64_t cols,
|
||||||
|
bool transposed,
|
||||||
|
int64_t ld,
|
||||||
|
int32_t batch_count,
|
||||||
|
int64_t batch_stride) {
|
||||||
|
cublasLtMatrixLayout_t desc;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
|
||||||
|
cublasLtOrder_t order =
|
||||||
|
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
|
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
|
||||||
|
if (batch_count > 1) {
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
|
desc,
|
||||||
|
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||||
|
&batch_count,
|
||||||
|
sizeof(int32_t)));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
|
desc,
|
||||||
|
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||||
|
&batch_stride,
|
||||||
|
sizeof(int64_t)));
|
||||||
|
}
|
||||||
|
return desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||||
|
cublasLtMatmulPreference_t pref_{nullptr};
|
||||||
|
cublasLtMatrixLayout_t a_desc_{nullptr};
|
||||||
|
cublasLtMatrixLayout_t b_desc_{nullptr};
|
||||||
|
cublasLtMatrixLayout_t c_desc_{nullptr};
|
||||||
|
cublasLtMatrixLayout_t out_desc_{nullptr};
|
||||||
|
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::tuple<bool, int64_t, array>
|
||||||
|
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
|
||||||
|
auto stx = arr.strides()[arr.ndim() - 2];
|
||||||
|
auto sty = arr.strides()[arr.ndim() - 1];
|
||||||
|
if (sty == 1 && stx == arr.shape(-1)) {
|
||||||
|
return std::make_tuple(false, stx, arr);
|
||||||
|
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||||
|
return std::make_tuple(true, sty, arr);
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||||
|
copies.push_back(arr_copy);
|
||||||
|
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Matmul::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a_pre = inputs[0];
|
||||||
|
auto& b_pre = inputs[1];
|
||||||
|
// Return 0s if either input is empty.
|
||||||
|
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||||
|
array zero(0, a_pre.dtype());
|
||||||
|
encoder.add_temporary(zero);
|
||||||
|
fill_gpu(zero, out, s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Init checks and prep
|
||||||
|
|
||||||
|
int M = a_pre.shape(-2);
|
||||||
|
int N = b_pre.shape(-1);
|
||||||
|
int K = a_pre.shape(-1);
|
||||||
|
|
||||||
|
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||||
|
// the arrays
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||||
|
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||||
|
|
||||||
|
for (auto& temp : copies) {
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Check and collapse batch dimensions
|
||||||
|
|
||||||
|
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||||
|
|
||||||
|
auto batch_count = out.size() / (M * N);
|
||||||
|
|
||||||
|
// Collapse batches into M if needed
|
||||||
|
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||||
|
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||||
|
b_batch_strides.back() == 0) {
|
||||||
|
M *= batch_shape.back();
|
||||||
|
batch_count = 1;
|
||||||
|
|
||||||
|
a_batch_strides = {0};
|
||||||
|
b_batch_strides = {0};
|
||||||
|
batch_shape = {1};
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Invoke cublasLt
|
||||||
|
|
||||||
|
cu::MatMul matmul(
|
||||||
|
encoder.device(),
|
||||||
|
a.dtype(),
|
||||||
|
a_transposed,
|
||||||
|
M,
|
||||||
|
K,
|
||||||
|
lda,
|
||||||
|
b_transposed,
|
||||||
|
K,
|
||||||
|
N,
|
||||||
|
ldb,
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back());
|
||||||
|
|
||||||
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
|
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||||
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
|
b.data<int8_t>() + b.itemsize() * b_it.loc);
|
||||||
|
a_it.step();
|
||||||
|
b_it.step();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("AddMM::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
assert(inputs.size() == 3);
|
||||||
|
auto& a_pre = inputs[0];
|
||||||
|
auto& b_pre = inputs[1];
|
||||||
|
auto& c_pre = inputs[2];
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Init checks and prep
|
||||||
|
|
||||||
|
int M = a_pre.shape(-2);
|
||||||
|
int N = b_pre.shape(-1);
|
||||||
|
int K = a_pre.shape(-1);
|
||||||
|
|
||||||
|
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||||
|
// the arrays
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
|
||||||
|
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
|
||||||
|
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
|
||||||
|
|
||||||
|
for (auto& temp : copies) {
|
||||||
|
encoder.add_temporary(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Check and collapse batch dimensions
|
||||||
|
|
||||||
|
auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
|
||||||
|
collapse_batches(a, b, c);
|
||||||
|
|
||||||
|
auto batch_count = out.size() / (M * N);
|
||||||
|
|
||||||
|
// Collapse batches into M if needed
|
||||||
|
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||||
|
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||||
|
c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
|
||||||
|
b_batch_strides.back() == 0) {
|
||||||
|
M *= batch_shape.back();
|
||||||
|
batch_count = 1;
|
||||||
|
|
||||||
|
a_batch_strides = {0};
|
||||||
|
b_batch_strides = {0};
|
||||||
|
c_batch_strides = {0};
|
||||||
|
batch_shape = {1};
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Invoke cublasLt
|
||||||
|
|
||||||
|
cu::MatMul matmul(
|
||||||
|
encoder.device(),
|
||||||
|
a.dtype(),
|
||||||
|
a_transposed,
|
||||||
|
M,
|
||||||
|
K,
|
||||||
|
lda,
|
||||||
|
b_transposed,
|
||||||
|
K,
|
||||||
|
N,
|
||||||
|
ldb,
|
||||||
|
c_transposed,
|
||||||
|
ldc,
|
||||||
|
batch_shape.back(),
|
||||||
|
a_batch_strides.back(),
|
||||||
|
b_batch_strides.back(),
|
||||||
|
c_batch_strides.back());
|
||||||
|
|
||||||
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
|
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
||||||
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||||
|
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||||
|
c.data<int8_t>() + c.itemsize() * c_it.loc,
|
||||||
|
alpha_,
|
||||||
|
beta_);
|
||||||
|
a_it.step();
|
||||||
|
b_it.step();
|
||||||
|
c_it.step();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -73,7 +73,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
|
|
||||||
NO_GPU(Abs)
|
NO_GPU(Abs)
|
||||||
NO_GPU(Add)
|
NO_GPU(Add)
|
||||||
NO_GPU(AddMM)
|
|
||||||
NO_GPU(ArcCos)
|
NO_GPU(ArcCos)
|
||||||
NO_GPU(ArcCosh)
|
NO_GPU(ArcCosh)
|
||||||
NO_GPU(ArcSin)
|
NO_GPU(ArcSin)
|
||||||
@ -124,7 +123,6 @@ NO_GPU(LogicalOr)
|
|||||||
NO_GPU(LogAddExp)
|
NO_GPU(LogAddExp)
|
||||||
NO_GPU(LogSumExp)
|
NO_GPU(LogSumExp)
|
||||||
NO_GPU_MULTI(LUF)
|
NO_GPU_MULTI(LUF)
|
||||||
NO_GPU(Matmul)
|
|
||||||
NO_GPU(Maximum)
|
NO_GPU(Maximum)
|
||||||
NO_GPU(Minimum)
|
NO_GPU(Minimum)
|
||||||
NO_GPU(Multiply)
|
NO_GPU(Multiply)
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/backend/common/broadcasting.h"
|
#include "mlx/backend/common/broadcasting.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/matmul.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
@ -21,69 +21,6 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline auto collapse_batches(const array& a, const array& b) {
|
|
||||||
// Get and check the shape for the batched dims
|
|
||||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
|
||||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
|
||||||
if (A_bshape != B_bshape) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
|
|
||||||
<< a.shape() << ", B " << b.shape() << ".";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
|
||||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
|
||||||
|
|
||||||
auto [batch_shape, batch_strides] =
|
|
||||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
|
||||||
|
|
||||||
auto A_batch_stride = batch_strides[0];
|
|
||||||
auto B_batch_stride = batch_strides[1];
|
|
||||||
|
|
||||||
if (batch_shape.empty()) {
|
|
||||||
batch_shape.push_back(1);
|
|
||||||
A_batch_stride.push_back(0);
|
|
||||||
B_batch_stride.push_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
|
||||||
// Get and check the shape for the batched dims
|
|
||||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
|
||||||
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
|
|
||||||
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
|
|
||||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
|
|
||||||
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
|
||||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
|
||||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
|
||||||
|
|
||||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
|
||||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
|
||||||
|
|
||||||
auto A_batch_stride = batch_strides[0];
|
|
||||||
auto B_batch_stride = batch_strides[1];
|
|
||||||
auto C_batch_stride = batch_strides[2];
|
|
||||||
|
|
||||||
if (batch_shape.empty()) {
|
|
||||||
batch_shape.push_back(1);
|
|
||||||
A_batch_stride.push_back(0);
|
|
||||||
B_batch_stride.push_back(0);
|
|
||||||
C_batch_stride.push_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_tuple(
|
|
||||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<bool, int64_t, array> check_transpose(
|
std::tuple<bool, int64_t, array> check_transpose(
|
||||||
std::vector<array>& copies,
|
std::vector<array>& copies,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
|
Loading…
Reference in New Issue
Block a user