mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
9 Commits
f2adb5638d
...
jagrit06/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
400f8457ea | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a | ||
|
|
56be773610 | ||
|
|
a9bdd67baa |
@@ -1,4 +1,5 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
||||
sphinx-copybutton
|
||||
mlx
|
||||
|
||||
@@ -18,6 +18,7 @@ release = version
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
extensions = [
|
||||
"sphinx_copybutton",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
|
||||
@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
|
||||
@@ -24,6 +24,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
@@ -39,6 +40,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
@@ -52,10 +54,10 @@ target_sources(
|
||||
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||
else()
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__global__ void set_mm_device_pointers(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_addmm_device_pointers(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
const __grid_constant__ Strides c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
|
||||
&batch_mode,
|
||||
sizeof(batch_mode)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||
}
|
||||
|
||||
void Matmul::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides) {
|
||||
auto batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
|
||||
{static_cast<int>(batch_count * 3)},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
static_cast<int64_t>(M_) * N_,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto out_pointers = b_pointers + batch_count;
|
||||
run_impl(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void Matmul::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
auto batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(c_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||
{static_cast<int>(batch_count * 4)},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
const_param(c_batch_strides),
|
||||
static_cast<int64_t>(M_) * N_,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto c_pointers = b_pointers + batch_count;
|
||||
auto out_pointers = c_pointers + batch_count;
|
||||
run_impl(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
reinterpret_cast<void*>(c_pointers),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -7,10 +7,12 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
struct CublasPreference {
|
||||
CublasPreference(Device& device) {
|
||||
CublasPreference(cu::Device& device) {
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
@@ -33,7 +35,7 @@ struct CublasPreference {
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
};
|
||||
|
||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
|
||||
static CublasPreference pref(device);
|
||||
return pref.pref_;
|
||||
}
|
||||
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
|
||||
return CUDA_C_32F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
||||
return desc;
|
||||
}
|
||||
|
||||
Matmul::Matmul(
|
||||
Device& device,
|
||||
} // namespace
|
||||
|
||||
CublasGemm::CublasGemm(
|
||||
cu::Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -155,8 +159,8 @@ Matmul::Matmul(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
}
|
||||
|
||||
Matmul::Matmul(
|
||||
Device& device,
|
||||
CublasGemm::CublasGemm(
|
||||
cu::Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -171,7 +175,7 @@ Matmul::Matmul(
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride)
|
||||
: Matmul(
|
||||
: CublasGemm(
|
||||
device,
|
||||
dtype,
|
||||
a_transposed,
|
||||
@@ -190,7 +194,7 @@ Matmul::Matmul(
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
Matmul::~Matmul() {
|
||||
CublasGemm::~CublasGemm() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void Matmul::run_impl(
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
return;
|
||||
}
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
c_batch_strides,
|
||||
alpha,
|
||||
beta);
|
||||
return;
|
||||
}
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c.data<void>(),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
void CublasGemm::execute(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
const void* a,
|
||||
@@ -256,29 +326,4 @@ void Matmul::run_impl(
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
void Matmul::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::optional<array>& c /* = std::nullopt */,
|
||||
float alpha /* = 1 */,
|
||||
float beta /* = 0 */) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
if (c) {
|
||||
encoder.set_input_array(*c);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
|
||||
run_impl(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c ? c->data<void>() : nullptr,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -5,13 +5,13 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <optional>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
class Matmul {
|
||||
namespace mlx::core {
|
||||
|
||||
class CublasGemm {
|
||||
public:
|
||||
Matmul(
|
||||
Device& device,
|
||||
CublasGemm(
|
||||
cu::Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -25,8 +25,8 @@ class Matmul {
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride);
|
||||
|
||||
Matmul(
|
||||
Device& device,
|
||||
CublasGemm(
|
||||
cu::Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -42,25 +42,39 @@ class Matmul {
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride);
|
||||
|
||||
~Matmul();
|
||||
~CublasGemm();
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::optional<array>& c = std::nullopt,
|
||||
float alpha = 1,
|
||||
float beta = 0);
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta);
|
||||
|
||||
private:
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides);
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
@@ -68,15 +82,14 @@ class Matmul {
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta);
|
||||
|
||||
private:
|
||||
void run_impl(
|
||||
void execute(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
const void* a,
|
||||
@@ -97,4 +110,4 @@ class Matmul {
|
||||
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -4,16 +4,16 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
namespace mlx::core {
|
||||
|
||||
void Matmul::run_batched(
|
||||
void CublasGemm::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides) {
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
@@ -22,7 +22,7 @@ void Matmul::run_batched(
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
run_impl(
|
||||
execute(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
@@ -33,16 +33,16 @@ void Matmul::run_batched(
|
||||
}
|
||||
}
|
||||
|
||||
void Matmul::run_batched(
|
||||
void CublasGemm::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
encoder.set_input_array(a);
|
||||
@@ -56,7 +56,7 @@ void Matmul::run_batched(
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
run_impl(
|
||||
execute(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
@@ -70,4 +70,4 @@ void Matmul::run_batched(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
} // namespace mlx::core
|
||||
327
mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
Normal file
327
mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
Normal file
@@ -0,0 +1,327 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <int NDIM>
|
||||
__global__ void set_mm_device_pointers_nd(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data());
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_mm_device_pointers_g(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
__global__ void set_addmm_device_pointers_nd(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data());
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_addmm_device_pointers_g(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
const __grid_constant__ Strides c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace {
|
||||
|
||||
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
|
||||
&batch_mode,
|
||||
sizeof(batch_mode)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(void*) * 3),
|
||||
{batch_count * 3},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
int block_dims = std::min(batch_count, 256);
|
||||
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||
int64_t batch_stride = M_ * N_;
|
||||
int item_size = out.itemsize();
|
||||
|
||||
int ndim = batch_shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers_nd<ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
const_param<ndim_constant()>(b_batch_strides),
|
||||
batch_stride,
|
||||
batch_count);
|
||||
});
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers_g,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
batch_stride,
|
||||
ndim,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto out_pointers = b_pointers + batch_count;
|
||||
execute(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(c_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||
{batch_count * 4},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
int block_dims = std::min(batch_count, 256);
|
||||
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||
int64_t batch_stride = M_ * N_;
|
||||
int item_size = out.itemsize();
|
||||
|
||||
int ndim = batch_shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers_nd<ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
const_param<ndim_constant()>(b_batch_strides),
|
||||
const_param<ndim_constant()>(c_batch_strides),
|
||||
batch_stride,
|
||||
batch_count);
|
||||
});
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers_g,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
const_param(c_batch_strides),
|
||||
batch_stride,
|
||||
ndim,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto c_pointers = b_pointers + batch_count;
|
||||
auto out_pointers = c_pointers + batch_count;
|
||||
execute(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
reinterpret_cast<void*>(c_pointers),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
301
mlx/backend/cuda/gemms/steel_gemm.cu
Normal file
301
mlx/backend/cuda/gemms/steel_gemm.cu
Normal file
@@ -0,0 +1,301 @@
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/gemms/steel_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <numeric>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#include "mlx/backend/cuda/steel/gemm.cuh"
|
||||
#include "mlx/backend/cuda/steel/mma.cuh"
|
||||
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
struct GemmParams {
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int lda;
|
||||
int ldb;
|
||||
int ldd;
|
||||
|
||||
int NblockM;
|
||||
int NblockN;
|
||||
int NblockK;
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int SL,
|
||||
int Nstages>
|
||||
__global__ void kernel_steel_gemm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
T* d,
|
||||
__grid_constant__ const GemmParams params) {
|
||||
const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
|
||||
const int bN_idx = blockIdx.x >> SL;
|
||||
|
||||
if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int d_row = bM_idx * BM;
|
||||
const int d_col = bN_idx * BN;
|
||||
const size_t d_row_long = size_t(d_row);
|
||||
const size_t d_col_long = size_t(d_col);
|
||||
|
||||
a += transpose_a ? d_row_long : d_row_long * params.K;
|
||||
b += transpose_b ? d_col_long * params.K : d_col_long;
|
||||
d += d_row_long * params.ldd + d_col_long;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
|
||||
const int lane_idx = warp.thread_rank();
|
||||
const int warp_idx = warp.meta_group_rank();
|
||||
|
||||
const int wm = warp_idx / WN;
|
||||
const int wn = warp_idx % WN;
|
||||
|
||||
constexpr int SM = BM / WM;
|
||||
constexpr int SN = BN / WN;
|
||||
constexpr int SK = BK;
|
||||
constexpr int TK = SK / 16;
|
||||
|
||||
constexpr int NUM_WARPS = WM * WN;
|
||||
|
||||
// Allocate shared memory
|
||||
extern __shared__ char shmem[];
|
||||
SharedTile<T, BM, BK>(&as)[Nstages] =
|
||||
*(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
|
||||
SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
|
||||
&shmem[sizeof(T) * Nstages * BM * BK]);
|
||||
|
||||
// Allocate registers for the MMA
|
||||
RegisterTile<float, SM, SN> C;
|
||||
RegisterTile<T, SM, 16> A[TK];
|
||||
RegisterTile<T, SN, 16> B[TK];
|
||||
|
||||
// Zero the accumulators
|
||||
C.fill(0);
|
||||
|
||||
// Start gmem -> smem copies
|
||||
int k_block_read = 0;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int bk = 0; bk < (Nstages - 1); bk++) {
|
||||
load_async<NUM_WARPS>(
|
||||
as[bk], as[bk].base_addr(), a + k_block_read, params.K);
|
||||
load_async<NUM_WARPS>(
|
||||
bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
|
||||
k_block_read += BK;
|
||||
cp_async_commit();
|
||||
}
|
||||
|
||||
int smem_pipe_read = 0;
|
||||
int smem_pipe_write = Nstages - 1;
|
||||
|
||||
// Wait till only 1 remains laoding
|
||||
cp_async_wait<1>();
|
||||
block.sync();
|
||||
|
||||
const int offset_m = wm * SM;
|
||||
const int offset_n = wn * SN;
|
||||
|
||||
// Start smem -> register copy
|
||||
A[0].load(
|
||||
as[smem_pipe_read],
|
||||
as[smem_pipe_read].base_addr(),
|
||||
offset_m + lane_idx % 16,
|
||||
lane_idx / 16 * 8);
|
||||
B[0].load(
|
||||
bs[smem_pipe_read],
|
||||
bs[smem_pipe_read].base_addr(),
|
||||
offset_n + lane_idx % 16,
|
||||
lane_idx / 16 * 8);
|
||||
|
||||
// Main loop
|
||||
for (int kb = 0; kb < params.NblockK; kb++) {
|
||||
// Prepare next registers
|
||||
{
|
||||
A[1].load(
|
||||
as[smem_pipe_read],
|
||||
as[smem_pipe_read].base_addr(),
|
||||
offset_m + lane_idx % 16,
|
||||
16 + lane_idx / 16 * 8);
|
||||
B[1].load(
|
||||
bs[smem_pipe_read],
|
||||
bs[smem_pipe_read].base_addr(),
|
||||
offset_n + lane_idx % 16,
|
||||
16 + lane_idx / 16 * 8);
|
||||
}
|
||||
|
||||
// Prepare next smem
|
||||
if ((kb + Nstages - 1) < params.NblockK) {
|
||||
load_async<NUM_WARPS>(
|
||||
as[smem_pipe_write],
|
||||
as[smem_pipe_write].base_addr(),
|
||||
a + k_block_read,
|
||||
params.K);
|
||||
load_async<NUM_WARPS>(
|
||||
bs[smem_pipe_write],
|
||||
bs[smem_pipe_write].base_addr(),
|
||||
b + k_block_read,
|
||||
params.K);
|
||||
}
|
||||
k_block_read += BK;
|
||||
|
||||
cp_async_commit();
|
||||
|
||||
smem_pipe_write = smem_pipe_read;
|
||||
smem_pipe_read = smem_pipe_read + 1;
|
||||
smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;
|
||||
|
||||
// Do current gemm
|
||||
mma_t(C, A[0], B[0]);
|
||||
|
||||
// Do wait for next register
|
||||
cp_async_wait<1>();
|
||||
block.sync();
|
||||
|
||||
// Prepare next register (smem_pipe_read has moved to the next)
|
||||
{
|
||||
A[0].load(
|
||||
as[smem_pipe_read],
|
||||
as[smem_pipe_read].base_addr(),
|
||||
offset_m + lane_idx % 16,
|
||||
lane_idx / 16 * 8);
|
||||
B[0].load(
|
||||
bs[smem_pipe_read],
|
||||
bs[smem_pipe_read].base_addr(),
|
||||
offset_n + lane_idx % 16,
|
||||
lane_idx / 16 * 8);
|
||||
}
|
||||
|
||||
// Do current gemm
|
||||
mma_t(C, A[1], B[1]);
|
||||
}
|
||||
|
||||
// Wait and clear
|
||||
cp_async_wait_all();
|
||||
block.sync();
|
||||
|
||||
C.store_global(d, params.ldd, offset_m, offset_n);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void dispatch_steel_gemm(
|
||||
const Stream& s,
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& d,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool a_transposed,
|
||||
bool b_transposed) {
|
||||
using DataType = cuda_type_t<float16_t>;
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(d);
|
||||
|
||||
constexpr int BM = 128;
|
||||
constexpr int BN = 128;
|
||||
constexpr int BK = 32;
|
||||
|
||||
constexpr int WM = 2;
|
||||
constexpr int WN = 2;
|
||||
|
||||
constexpr int SL = 0;
|
||||
constexpr int Nstages = 3;
|
||||
|
||||
constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);
|
||||
|
||||
const int NblockM = (M + BM - 1) / BM;
|
||||
const int NblockN = (N + BN - 1) / BN;
|
||||
const int NblockK = (K + BK - 1) / BK;
|
||||
|
||||
cu::GemmParams params{
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ ldd,
|
||||
|
||||
/* int NblockM = */ NblockM,
|
||||
/* int NblockN = */ NblockN,
|
||||
/* int NblockK = */ NblockK,
|
||||
};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << SL;
|
||||
int tm = (NblockM + tile - 1) / tile;
|
||||
int tn = NblockN * tile;
|
||||
|
||||
dim3 grid_dim(tn, tm, 1);
|
||||
dim3 block_dim(32 * WM * WN, 1, 1);
|
||||
|
||||
dispatch_bool(a_transposed, [&](auto ta_) {
|
||||
dispatch_bool(b_transposed, [&](auto tb_) {
|
||||
constexpr bool ta = ta_.value;
|
||||
constexpr bool tb = tb_.value;
|
||||
|
||||
auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
smem_bytes,
|
||||
a.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
d.data<DataType>(),
|
||||
N,
|
||||
K);
|
||||
|
||||
// auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
|
||||
// tb, SL, Nstages>;
|
||||
|
||||
// cudaFuncSetAttribute(kernel,
|
||||
// cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||
|
||||
// encoder.add_kernel_node(
|
||||
// kernel,
|
||||
// grid_dim,
|
||||
// block_dim,
|
||||
// smem_bytes,
|
||||
// a.data<DataType>(),
|
||||
// b.data<DataType>(),
|
||||
// d.data<DataType>(),
|
||||
// params);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
27
mlx/backend/cuda/gemms/steel_gemm.h
Normal file
27
mlx/backend/cuda/gemms/steel_gemm.h
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void dispatch_steel_gemm(
|
||||
const Stream& s,
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& d,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldd,
|
||||
bool a_transposed,
|
||||
bool b_transposed);
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "mlx/backend/cuda/gemms/steel_gemm.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <numeric>
|
||||
|
||||
@@ -95,9 +97,27 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
|
||||
b_transposed) {
|
||||
return dispatch_steel_gemm(
|
||||
/* const Stream& s = */ s,
|
||||
/* cu::CommandEncoder& encoder = */ encoder,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* array& d = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* int ldd = */ N,
|
||||
/* bool a_transposed = */ a_transposed,
|
||||
/* bool b_transposed = */ b_transposed);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
cu::Matmul matmul(
|
||||
CublasGemm gemm(
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
@@ -111,14 +131,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
|
||||
if ((batch_count / batch_shape.back()) == 1) {
|
||||
matmul.run(encoder, out, a, b);
|
||||
return;
|
||||
}
|
||||
|
||||
matmul.run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -186,7 +199,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
cu::Matmul matmul(
|
||||
CublasGemm gemm(
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
@@ -202,12 +215,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back(),
|
||||
c_batch_strides.back());
|
||||
|
||||
if ((batch_count / batch_shape.back()) == 1) {
|
||||
matmul.run(encoder, out, a, b, c, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
matmul.run_batched(
|
||||
gemm.run(
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
|
||||
@@ -6,17 +6,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
bool fast::ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
|
||||
781
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
781
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
@@ -0,0 +1,781 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
#define PRAGMA_LOOP_UNROLL #pragma unroll
|
||||
|
||||
struct AttnParams {
|
||||
int B;
|
||||
int H;
|
||||
int D;
|
||||
|
||||
int qL;
|
||||
int kL;
|
||||
|
||||
int gqa_factor;
|
||||
float scale;
|
||||
|
||||
int64_t Q_strides[3];
|
||||
int64_t K_strides[3];
|
||||
int64_t V_strides[3];
|
||||
int64_t O_strides[3];
|
||||
};
|
||||
|
||||
template <typename T, bool do_causal, int D>
|
||||
__global__ void kernel_sdpav_1pass(
|
||||
const T* Q,
|
||||
const T* K,
|
||||
const T* V,
|
||||
T* O,
|
||||
__grid_constant__ const AttnParams params) {
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
|
||||
constexpr int v_per_thread = D / BD;
|
||||
|
||||
const int inner_k_stride = BN * int(params.K_strides[2]);
|
||||
const int inner_v_stride = BN * int(params.V_strides[2]);
|
||||
|
||||
typedef float U;
|
||||
|
||||
U q[v_per_thread];
|
||||
U k[v_per_thread];
|
||||
U o[v_per_thread];
|
||||
|
||||
__shared__ U outputs[BN][BD + 1];
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
|
||||
const int lane_idx = warp.thread_rank();
|
||||
const int warp_idx = warp.meta_group_rank();
|
||||
|
||||
// Adjust to thread block and thread
|
||||
const int batch_idx = blockIdx.z;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||
|
||||
const int q_seq_idx = blockIdx.y;
|
||||
const int kv_seq_idx = warp_idx;
|
||||
|
||||
Q += batch_idx * params.Q_strides[0] + // Batch
|
||||
head_idx * params.Q_strides[1] + // Head
|
||||
q_seq_idx * params.Q_strides[2]; // Sequence
|
||||
|
||||
K += batch_idx * params.K_strides[0] + // Batch
|
||||
kv_head_idx * params.K_strides[1] + // Head
|
||||
kv_seq_idx * params.K_strides[2]; // Sequence
|
||||
|
||||
V += batch_idx * params.V_strides[0] + // Batch
|
||||
kv_head_idx * params.V_strides[1] + // Head
|
||||
kv_seq_idx * params.V_strides[2]; // Sequence
|
||||
|
||||
O += batch_idx * params.O_strides[0] + // Batch
|
||||
head_idx * params.O_strides[1] + // Head
|
||||
q_seq_idx * params.O_strides[2]; // Sequence
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
|
||||
}
|
||||
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
o[i] = 0.f;
|
||||
}
|
||||
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0.f;
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += BN) {
|
||||
bool use_key = true;
|
||||
if constexpr (do_causal) {
|
||||
use_key = i <= (params.kL - params.qL + q_seq_idx);
|
||||
}
|
||||
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 0; j < v_per_thread; j++) {
|
||||
k[j] = K[v_per_thread * lane_idx + j];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0.f;
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 0; j < v_per_thread; j++) {
|
||||
score += q[j] * k[j];
|
||||
}
|
||||
|
||||
// Warp sum
|
||||
score = cg::reduce(warp, score, cg::plus<U>());
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
// Update the output accumulator
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 0; j < v_per_thread; j++) {
|
||||
o[j] = o[j] * factor +
|
||||
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
K += inner_k_stride;
|
||||
V += inner_v_stride;
|
||||
}
|
||||
|
||||
if (lane_idx == 0) {
|
||||
max_scores[warp_idx] = max_score;
|
||||
sum_exp_scores[warp_idx] = sum_exp_score;
|
||||
}
|
||||
block.sync();
|
||||
|
||||
max_score = max_scores[lane_idx];
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = exp2f(max_score - new_max);
|
||||
sum_exp_score =
|
||||
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
|
||||
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
outputs[lane_idx][warp_idx] = o[i];
|
||||
block.sync();
|
||||
U ot = outputs[warp_idx][lane_idx] * factor;
|
||||
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
|
||||
block.sync();
|
||||
}
|
||||
|
||||
// And write the output
|
||||
if (lane_idx == 0) {
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool do_causal, int D>
|
||||
__global__ void kernel_sdpav_2pass_1(
|
||||
const T* Q,
|
||||
const T* K,
|
||||
const T* V,
|
||||
float* partials,
|
||||
float* sums,
|
||||
float* maxs,
|
||||
__grid_constant__ const AttnParams params) {
|
||||
constexpr int BN = 8;
|
||||
constexpr int BD = 32;
|
||||
constexpr int blocks = 32;
|
||||
|
||||
constexpr int v_per_thread = D / BD;
|
||||
|
||||
const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
|
||||
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
|
||||
|
||||
typedef float U;
|
||||
|
||||
U q[v_per_thread];
|
||||
U k[v_per_thread];
|
||||
U o[v_per_thread];
|
||||
|
||||
__shared__ U outputs[BN][BD + 1];
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
|
||||
const int lane_idx = warp.thread_rank();
|
||||
const int warp_idx = warp.meta_group_rank();
|
||||
|
||||
// Adjust to thread block and thread
|
||||
const int batch_idx = blockIdx.z / blocks;
|
||||
const int block_idx = blockIdx.z % blocks;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int kv_head_idx = head_idx / params.gqa_factor;
|
||||
|
||||
const int q_seq_idx = blockIdx.y;
|
||||
const int kv_seq_idx = block_idx * BN + warp_idx;
|
||||
|
||||
Q += batch_idx * params.Q_strides[0] + // Batch
|
||||
head_idx * params.Q_strides[1] + // Head
|
||||
q_seq_idx * params.Q_strides[2]; // Sequence
|
||||
|
||||
K += batch_idx * params.K_strides[0] + // Batch
|
||||
kv_head_idx * params.K_strides[1] + // Head
|
||||
kv_seq_idx * params.K_strides[2]; // Sequence
|
||||
|
||||
V += batch_idx * params.V_strides[0] + // Batch
|
||||
kv_head_idx * params.V_strides[1] + // Head
|
||||
kv_seq_idx * params.V_strides[2]; // Sequence
|
||||
|
||||
const int p_stride_s = blocks;
|
||||
const int p_stride_h = params.qL * p_stride_s;
|
||||
const int p_stride_b = params.H * p_stride_h;
|
||||
const int p_offset = batch_idx * p_stride_b + // Batch
|
||||
head_idx * p_stride_h + // Head
|
||||
q_seq_idx * p_stride_s + // Sequence
|
||||
block_idx; // Block
|
||||
|
||||
partials += p_offset * D;
|
||||
sums += p_offset;
|
||||
maxs += p_offset;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
|
||||
}
|
||||
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
o[i] = 0.f;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U sum_exp_score = 0.f;
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
|
||||
bool use_key = true;
|
||||
if constexpr (do_causal) {
|
||||
use_key = i <= (params.kL - params.qL + q_seq_idx);
|
||||
}
|
||||
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 0; j < v_per_thread; j++) {
|
||||
k[j] = K[v_per_thread * lane_idx + j];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0.f;
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 0; j < v_per_thread; j++) {
|
||||
score += q[j] * k[j];
|
||||
}
|
||||
|
||||
// Warp sum
|
||||
score = cg::reduce(warp, score, cg::plus<U>());
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U exp_score = exp2f(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
// Update the output accumulator
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 0; j < v_per_thread; j++) {
|
||||
o[j] = o[j] * factor +
|
||||
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
K += inner_k_stride;
|
||||
V += inner_v_stride;
|
||||
}
|
||||
|
||||
if (lane_idx == 0) {
|
||||
max_scores[warp_idx] = max_score;
|
||||
sum_exp_scores[warp_idx] = sum_exp_score;
|
||||
}
|
||||
|
||||
block.sync();
|
||||
|
||||
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = exp2f(max_score - new_max);
|
||||
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
|
||||
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
|
||||
|
||||
// Write the sum and new max
|
||||
if (warp_idx == 0) {
|
||||
sums[0] = sum_exp_score;
|
||||
maxs[0] = new_max;
|
||||
}
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
auto ff = exp2f(max_scores[warp_idx] - new_max);
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
outputs[warp_idx][lane_idx] = o[i] * ff;
|
||||
block.sync();
|
||||
|
||||
if (warp_idx == 0) {
|
||||
U ot = outputs[0][lane_idx];
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int j = 1; j < BN; j++) {
|
||||
ot += outputs[j][lane_idx];
|
||||
warp.sync();
|
||||
}
|
||||
o[i] = ot;
|
||||
}
|
||||
block.sync();
|
||||
}
|
||||
|
||||
if (warp_idx == 0) {
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
partials[v_per_thread * lane_idx + i] = o[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
    const float* partials,
    const float* sums,
    const float* maxs,
    T* O,
    __grid_constant__ const AttnParams params) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int blocks = 32;

  constexpr int v_per_thread = D / BD;

  typedef float U;

  U o[v_per_thread];
  __shared__ U outputs[BN][BD + 1];

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  // Adjust to thread block and thread
  const int batch_idx = blockIdx.z;
  const int head_idx = blockIdx.x;
  const int q_seq_idx = blockIdx.y;

  const int p_stride_s = blocks;
  const int p_stride_h = params.qL * p_stride_s;
  const int p_stride_b = params.H * p_stride_h;
  const int p_offset = batch_idx * p_stride_b + // Batch
      head_idx * p_stride_h + // Head
      q_seq_idx * p_stride_s; // Sequence

  partials += p_offset * D + warp_idx * D;
  sums += p_offset;
  maxs += p_offset;

  O += batch_idx * params.O_strides[0] + // Batch
      head_idx * params.O_strides[1] + // Head
      q_seq_idx * params.O_strides[2]; // Sequence

  U max_score = maxs[lane_idx];
  U new_max = cg::reduce(warp, max_score, cg::greater<U>());
  U factor = exp2f(max_score - new_max);
  U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
  sum_exp_score = __frcp_rn(sum_exp_score);

  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = partials[v_per_thread * lane_idx + i];
  }

  // Now we need to aggregate all the outputs
  PRAGMA_LOOP_UNROLL
  for (int i = 0; i < v_per_thread; i++) {
    outputs[lane_idx][warp_idx] = o[i];
    block.sync();
    U ot = outputs[warp_idx][lane_idx] * factor;
    o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
    block.sync();
  }

  // And write the output
  if (lane_idx == 0) {
    PRAGMA_LOOP_UNROLL
    for (int i = 0; i < v_per_thread; i++) {
      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
    }
  }
}
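For reference, the split-K reduction implemented by kernel_sdpav_2pass_1 and kernel_sdpav_2pass_2 keeps one (max, sum, accumulator) triple per block in the base-2 domain, then rescales every triple to a common maximum before normalizing. Below is a minimal host-side sketch of that same merge; it is illustrative only and not part of this diff, and the scores, values, and block count are made up.

// Standalone sketch of the two-pass online-softmax merge (illustrative only).
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-position scores (log2 domain) and scalar "values".
  std::vector<float> s = {1.f, 4.f, 2.f, 8.f, 3.f, 5.f, 7.f, 6.f};
  std::vector<float> v = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};
  const int blocks = 2;
  const int per_block = static_cast<int>(s.size()) / blocks;

  // Pass 1: each block keeps a running (max, sum, accumulator) triple,
  // rescaling by exp2(old_max - new_max) whenever the max grows.
  std::vector<float> maxs(blocks, -1e9f), sums(blocks, 0.f), accs(blocks, 0.f);
  for (int b = 0; b < blocks; b++) {
    for (int i = b * per_block; i < (b + 1) * per_block; i++) {
      float new_max = std::max(maxs[b], s[i]);
      float factor = std::exp2(maxs[b] - new_max);
      float e = std::exp2(s[i] - new_max);
      sums[b] = sums[b] * factor + e;
      accs[b] = accs[b] * factor + e * v[i];
      maxs[b] = new_max;
    }
  }

  // Pass 2: rescale every block to the global max and normalize.
  float gmax = *std::max_element(maxs.begin(), maxs.end());
  float gsum = 0.f, out = 0.f;
  for (int b = 0; b < blocks; b++) {
    float f = std::exp2(maxs[b] - gmax);
    gsum += sums[b] * f;
    out += accs[b] * f;
  }
  out /= gsum;

  // Reference: single-pass softmax-weighted average over all positions.
  float rsum = 0.f, rout = 0.f;
  for (size_t i = 0; i < s.size(); i++) {
    float e = std::exp2(s[i] - gmax);
    rsum += e;
    rout += e * v[i];
  }
  rout /= rsum;

  assert(std::fabs(out - rout) < 1e-5f);
  std::printf("merged %f vs reference %f\n", out, rout);
  return 0;
}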
} // namespace cu

namespace {

template <typename F>
void dispatch_headdim(int n, F&& f) {
  switch (n) {
    case 64:
      f(std::integral_constant<int, 64>{});
      break;
    case 96:
      f(std::integral_constant<int, 96>{});
      break;
    case 128:
      f(std::integral_constant<int, 128>{});
      break;
  }
}

void sdpa_vector_1pass_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  encoder.set_input_array(q);
  encoder.set_input_array(k);
  encoder.set_input_array(v);
  encoder.set_output_array(o);

  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
      /* int D = */ q.shape(3),

      /* int qL = */ q.shape(2),
      /* int kL = */ k.shape(2),

      /* int gqa_factor = */ q.shape(1) / k.shape(1),
      /* float scale = */ scale,

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  dim3 grid_dim(params.H, params.qL, params.B);
  dim3 block_dim(1024, 1, 1);

  dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

        auto kernel =
            cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
        encoder.add_kernel_node(
            kernel,
            grid_dim,
            block_dim,
            0,
            q.data<DataType>(),
            k.data<DataType>(),
            v.data<DataType>(),
            o.data<DataType>(),
            params);
      });
    });
  });
}
void sdpa_vector_2pass_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  cu::AttnParams params{
      /* int B = */ q.shape(0),
      /* int H = */ q.shape(1),
      /* int D = */ q.shape(3),

      /* int qL = */ q.shape(2),
      /* int kL = */ k.shape(2),

      /* int gqa_factor = */ q.shape(1) / k.shape(1),
      /* float scale = */ scale,

      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  // Allocate the intermediates
  int blocks = 32;

  Shape intermediate_shape;
  intermediate_shape.reserve(o.ndim() + 1);
  intermediate_shape.insert(
      intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
  intermediate_shape.push_back(blocks);
  intermediate_shape.push_back(o.shape().back());

  array intermediate(intermediate_shape, float32, nullptr, {});
  intermediate_shape.pop_back();
  array sums(intermediate_shape, float32, nullptr, {});
  array maxs(std::move(intermediate_shape), float32, nullptr, {});

  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
  sums.set_data(allocator::malloc(sums.nbytes()));
  maxs.set_data(allocator::malloc(maxs.nbytes()));

  encoder.add_temporary(intermediate);
  encoder.add_temporary(sums);
  encoder.add_temporary(maxs);

  dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
    dispatch_bool(do_causal_, [&](auto do_causal) {
      dispatch_headdim(params.D, [&](auto headdim) {
        using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;

        {
          auto kernel = cu::
              kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;

          encoder.set_input_array(q);
          encoder.set_input_array(k);
          encoder.set_input_array(v);
          encoder.set_output_array(intermediate);
          encoder.set_output_array(sums);
          encoder.set_output_array(maxs);

          dim3 grid_dim(params.H, params.qL, params.B * 32);
          dim3 block_dim(8 * 32, 1, 1);

          encoder.add_kernel_node(
              kernel,
              grid_dim,
              block_dim,
              0,
              q.data<DataType>(),
              k.data<DataType>(),
              v.data<DataType>(),
              intermediate.data<float>(),
              sums.data<float>(),
              maxs.data<float>(),
              params);
        }

        {
          auto kernel = cu::
              kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;

          encoder.set_input_array(intermediate);
          encoder.set_input_array(sums);
          encoder.set_input_array(maxs);
          encoder.set_output_array(o);

          dim3 grid_dim(params.H, params.qL, params.B);
          dim3 block_dim(1024, 1, 1);

          encoder.add_kernel_node(
              kernel,
              grid_dim,
              block_dim,
              0,
              intermediate.data<float>(),
              sums.data<float>(),
              maxs.data<float>(),
              o.data<DataType>(),
              params);
        }
      });
    });
  });
}
void sdpa_vector_fallback(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& q,
    const array& k,
    const array& v,
    const float scale,
    array& o,
    bool do_causal_ = false) {
  int kL = k.shape(2);

  if (kL > 1024) {
    return sdpa_vector_2pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
  } else {
    return sdpa_vector_1pass_fallback(
        s, encoder, q, k, v, scale, o, do_causal_);
  }
}
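The dispatch above switches to the two-pass kernels once the KV sequence is longer than 1024, splitting it across 32 blocks at the cost of float32 intermediates that the second pass merges. A rough sketch of that temporary footprint follows; it is illustrative only, and the example shape is an assumption rather than something taken from this diff.

// Back-of-the-envelope size of the 2-pass temporaries (illustrative only).
#include <cstdio>

int main() {
  // Assumed decoding-style shape: batch 1, 32 heads, one query, head dim 128.
  const int B = 1, H = 32, qL = 1, D = 128, blocks = 32;
  const long long partials = 1LL * B * H * qL * blocks * D; // float32 elements
  const long long sums_maxs = 2LL * B * H * qL * blocks;    // sums + maxs
  std::printf(
      "2-pass temporaries: %.2f KiB\n", 4.0 * (partials + sums_maxs) / 1024.0);
  return 0;
}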
} // namespace

namespace fast {
bool ScaledDotProductAttention::use_fallback(
    const array& q,
    const array& k,
    const array& v,
    bool has_mask,
    bool has_arr_mask,
    bool do_causal,
    Stream s) {
  if (detail::in_grad_tracing()) {
    return true;
  }
  if (s.device == Device::cpu) {
    return true;
  }

  const int value_head_dim = v.shape(-1);
  const int query_head_dim = q.shape(-1);
  const int query_sequence_length = q.shape(2);
  const int key_sequence_length = k.shape(2);

  const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);

  const bool supported_vector_config =
      sdpa_supported_head_dim && query_sequence_length < 4;

  const bool supported_config = supported_vector_config;

  return has_arr_mask || !supported_config;
}
void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    array& out) {
  nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  auto& q_pre = inputs[0];
  auto& k_pre = inputs[1];
  auto& v_pre = inputs[2];
  auto& o = out;

  std::vector<array> copies;

  // Define some copy functions to ensure the layout of the inputs is as
  // expected.
  copies.reserve(3);
  auto copy_unless = [&copies, &s](
                         auto predicate, const array& arr) -> const array& {
    if (!predicate(arr)) {
      array arr_copy = contiguous_copy_gpu(arr, s);
      copies.push_back(std::move(arr_copy));
      return copies.back();
    } else {
      return arr;
    }
  };

  // We are in vector mode ie single query
  if (q_pre.shape(2) < 4) {
    auto q_copy_unless = [](const array& arr) {
      if (arr.flags().row_contiguous) {
        return true;
      }
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (shape[0] == 1 || shape[1] == 1) {
        // If either the batch or head dimension is a singleton, the other can
        // be transposed with the sequence dimension
        auto bidx = shape[0] == 1 ? 1 : 0;
        return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
            (strides[bidx] == shape[3]);
      }
      return false;
    };

    auto kv_copy_unless = [](const array& arr) {
      // keys and values should be copied if:
      // - the last dimension is not contiguous
      // - the batch and head dim are not contiguous
      auto& strides = arr.strides();
      auto& shape = arr.shape();
      if (strides.back() != 1) {
        return false;
      }
      if (shape[0] == 1 || shape[1] == 1) {
        return true;
      }
      return (strides[0] == strides[1] * shape[1]);
    };

    const auto& q = copy_unless(q_copy_unless, q_pre);
    const auto& k = copy_unless(kv_copy_unless, k_pre);
    const auto& v = copy_unless(kv_copy_unless, v_pre);

    for (const auto& cp : copies) {
      encoder.add_temporary(cp);
    }

    // Donate the query if possible
    if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
      o.copy_shared_buffer(q);
    } else {
      int64_t str_oD = 1;
      int64_t str_oH = o.shape(3);
      int64_t str_oL = o.shape(1) * str_oH;
      int64_t str_oB = o.shape(2) * str_oL;
      size_t data_size = o.shape(0) * str_oB;

      array::Flags flags{
          /* bool contiguous = */ 1,
          /* bool row_contiguous = */ o.shape(2) == 1,
          /* bool col_contiguous = */ 0,
      };

      o.set_data(
          allocator::malloc(o.nbytes()),
          data_size,
          {str_oB, str_oH, str_oL, str_oD},
          flags);
    }

    return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
  }

  // Full attention mode should never reach here
  else {
    throw std::runtime_error("Doesn't support matrix yet.");
  }
}

} // namespace fast

} // namespace mlx::core
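In the non-donated branch of eval_gpu above, the output keeps its logical (B, H, L, D) shape while its buffer is packed in (B, L, H, D) order, which presumably lets the vector kernels write each query position contiguously across heads. A small sketch of that stride arithmetic follows; it is illustrative only and the example shape is an assumption.

// Stride layout used by the non-donated output path (illustrative only).
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example shape, in logical (B, H, L, D) order.
  const int64_t B = 2, H = 8, L = 1, D = 64;
  const int64_t str_oD = 1;
  const int64_t str_oH = D;          // each head is a contiguous block of D
  const int64_t str_oL = H * str_oH; // one query position spans all heads
  const int64_t str_oB = L * str_oL; // one batch spans all query positions
  std::printf(
      "strides (B, H, L, D) = (%lld, %lld, %lld, %lld), elements = %lld\n",
      (long long)str_oB,
      (long long)str_oH,
      (long long)str_oL,
      (long long)str_oD,
      (long long)(B * str_oB));
  return 0;
}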
@@ -143,85 +143,87 @@ struct Tile16x16 {
  }
};

/**
 * A simple container of multiple Tile16x16.
 *
 * Provides utility functions for loading and manipulating collections of basic
 * tiles.
 */
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;
// /**
// * A simple container of multiple Tile16x16.
// *
// * Provides utility functions for loading and manipulating collections of
// basic
// * tiles.
// */
// template <typename T, int ROWS_, int COLS_>
// struct RegisterTile {
// static constexpr int ROWS = ROWS_;
// static constexpr int COLS = COLS_;
// static constexpr int TILES_X = COLS / 16;
// static constexpr int TILES_Y = ROWS / 16;

  Tile16x16<T> data[TILES_X * TILES_Y];
// Tile16x16<T> data[TILES_X * TILES_Y];

  __device__ inline void fill(T v) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].fill(v);
      }
    }
  }
// __device__ inline void fill(T v) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].fill(v);
// }
// }
// }
  template <typename Tile>
  __device__ __forceinline__ void
  load(Tile& tile, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].load(
            tile.loc(base_address, row + i * 16, col + j * 16));
      }
    }
  }
// template <typename Tile>
// __device__ __forceinline__ void
// load(Tile& tile, uint32_t base_address, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].load(
// tile.loc(base_address, row + i * 16, col + j * 16));
// }
// }
// }

  template <typename Tile, typename F>
  __device__ __forceinline__ void
  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        f(data[i * TILES_X + j],
          tile,
          base_address,
          row + i * 16,
          col + j * 16);
      }
    }
  }
// template <typename Tile, typename F>
// __device__ __forceinline__ void
// load(Tile& tile, F f, uint32_t base_address, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// f(data[i * TILES_X + j],
// tile,
// base_address,
// row + i * 16,
// col + j * 16);
// }
// }
// }
  template <typename U>
  __device__ inline void store_global(U* x, int N, int row, int col) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].store_global(
            x + (row + i * 16) * N + col + j * 16, N);
      }
    }
  }
// template <typename U>
// __device__ inline void store_global(U* x, int N, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].store_global(
// x + (row + i * 16) * N + col + j * 16, N);
// }
// }
// }

  template <typename U>
  __device__ inline void
  store_global_safe(U* x, int N, int row, int col, int max_rows) {
    MLX_UNROLL
    for (int i = 0; i < TILES_Y; i++) {
      MLX_UNROLL
      for (int j = 0; j < TILES_X; j++) {
        data[i * TILES_X + j].store_global_safe(
            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
      }
    }
  }
};
// template <typename U>
// __device__ inline void
// store_global_safe(U* x, int N, int row, int col, int max_rows) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].store_global_safe(
// x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
// 16);
// }
// }
// }
// };

/**
 * A simple container of multiple Tile16x16.
@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
instantiate_and_or(or, Or)

#define instantiate_sum_prod(name, op) \
  instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
  instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
  instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
  instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
    const std::string& op_name) {
  if (op_name == "sum" || op_name == "prod") {
    if (issubdtype(in.dtype(), integer)) {
      switch (in.dtype().size()) {
        case 1:
      switch (in.dtype()) {
        case uint8:
          return {uint8, uint32};
        case uint16:
          return {uint16, uint32};
        case uint32:
          return {uint32, uint32};
        case uint64:
          return {uint64, uint64};
        case int8:
          return {int8, int32};
        case 2:
        case int16:
          return {int16, int32};
        case 4:
        case int32:
          return {int32, int32};
        case 8:
        case int64:
          return {int64, int64};
        default:
          throw std::runtime_error("Unsupported integer type");
      }
    }
    if (in.dtype() == bool_) {
mlx/ops.cpp
@@ -2381,9 +2381,20 @@ array logsumexp(
    throw std::invalid_argument(
        "[logsumexp] Received non-empty axes for array with 0 dimensions.");
  }
  bool reduce_last_dim =
      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
  if (reduce_last_dim) {
    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
    // is [1, 1, ..., N].
    for (int i = axes.size() - 2; i >= 0; --i) {
      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
        reduce_last_dim = false;
        break;
      }
    }
  }
  bool is_complex = issubdtype(a.dtype(), complexfloating);
  if (!is_complex && axes.size() == 1 &&
      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
  if (!is_complex && reduce_last_dim) {
    auto dtype = at_least_float(a.dtype());
    auto out_shape = a.shape();
    out_shape.back() = 1;
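The check added here, and mirrored in the softmax hunk below, generalizes the old single-axis test: the reduction must include the last axis, and any additional reduced axes must form a consecutive run ending at it over singleton dimensions (e.g. shape [1, 1, N] with axes [0, 1, 2]). A standalone sketch of the same predicate follows; it is illustrative only and written against plain vectors rather than mlx arrays.

// Standalone version of the reduce_last_dim predicate (illustrative only).
#include <cassert>
#include <vector>

bool reduce_last_dim(
    const std::vector<int>& shape,
    const std::vector<int>& axes) {
  int ndim = static_cast<int>(shape.size());
  bool last = !axes.empty() && (axes.back() == ndim - 1 || axes.back() == -1);
  if (last) {
    // All axes before the last one must be consecutive and of size 1.
    for (int i = static_cast<int>(axes.size()) - 2; i >= 0; --i) {
      if ((axes[i] + 1 != axes[i + 1]) || (shape[axes[i]] != 1)) {
        return false;
      }
    }
  }
  return last;
}

int main() {
  assert(reduce_last_dim({4, 8, 16}, {2}));       // plain last-axis reduce
  assert(reduce_last_dim({1, 1, 16}, {0, 1, 2})); // leading singletons
  assert(!reduce_last_dim({2, 8, 16}, {1, 2}));   // axis 1 is not a singleton
  assert(!reduce_last_dim({4, 8, 16}, {0, 1}));   // last axis not reduced
  return 0;
}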
@@ -3403,10 +3414,20 @@ array softmax(
    throw std::invalid_argument(
        "[softmax] Received non-empty axes for array with 0 dimensions.");
  }

  bool reduce_last_dim =
      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
  if (reduce_last_dim) {
    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
    // is [1, 1, ..., N].
    for (int i = axes.size() - 2; i >= 0; --i) {
      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
        reduce_last_dim = false;
        break;
      }
    }
  }
  bool is_complex = issubdtype(a.dtype(), complexfloating);
  if (!is_complex && axes.size() == 1 &&
      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
  if (!is_complex && reduce_last_dim) {
    auto dtype = at_least_float(a.dtype());
    return array(
        a.shape(),
@@ -3,8 +3,8 @@
#pragma once

#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 27
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_MINOR 28
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_NUMERIC \
  (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
@@ -2,6 +2,6 @@
requires = [
    "setuptools>=80",
    "nanobind==2.4.0",
    "cmake>=3.25",
    "cmake>=3.25,<4.1",
]
build-backend = "setuptools.build_meta"
@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
    CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
  }

  // sum and prod overflow
  {
    auto a = full({256, 2, 2}, 1u, uint8);
    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);

    a = full({65535, 2, 2}, 1u, uint16);
    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
  }
}
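The overflow cases above exercise the sum/prod type remapping shown earlier, where uint8 and uint16 inputs now accumulate into (and return) uint32 rather than wrapping in their native width. A minimal illustration of the wrap-around being avoided follows; it is not part of this diff.

// Why the widened accumulator matters for sum(uint8) (illustrative only).
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t narrow = 0;
  uint32_t wide = 0;
  // 256 * 4 ones, matching the {256, 2, 2} uint8 test case above.
  for (int i = 0; i < 256 * 4; i++) {
    narrow = static_cast<uint8_t>(narrow + 1); // wraps back to 0
    wide += 1;                                 // stays exact: 1024
  }
  std::printf("uint8 accumulator: %u, uint32 accumulator: %u\n", narrow, wide);
  return 0;
}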
TEST_CASE("test gpu reduce with axes") {
  // reducing only some axes and irregular layouts
  {
    array a(1.0f);
@@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
    CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
  }

  // Test unsigned sum
  {
    const int num_elems = 1000;

    auto x = astype(full({num_elems}, 255), uint8);
    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);

    x = astype(full({num_elems}, 65535), uint16);
    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);

    x = full({3, 3, 3}, 10000, uint32);
    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);

    x = full({3, 3, 3}, 10000, uint64);
    CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
  }

  // Test prod
  {
    auto x = array({});
@@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
    CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
  }

  // Test unsigned prod
  {
    auto x = array({255, 255}, {2}, uint8);
    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);

    x = array({65535, 2}, {2}, uint16);
    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);

    x = array({100000, 2}, {2}, uint32);
    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);

    x = array({100000, 2}, {2}, uint64);
    CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
  }

  // Test all
  {
    auto x = array({});