Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: 11 commits, sdpav-back ... jagrit06/c
| Author | SHA1 | Date |
|---|---|---|
|  | 400f8457ea |  |
|  | dfb5022eab |  |
|  | ac207ce7aa |  |
|  | fce53b61d6 |  |
|  | 8ae4a76308 |  |
|  | 7fde1b6a1e |  |
|  | aa7b47481a |  |
|  | 56be773610 |  |
|  | a9bdd67baa |  |
|  | f2adb5638d |  |
|  | 728d4db582 |  |
@@ -1,4 +1,5 @@
 sphinx
 breathe
 sphinx-book-theme
+sphinx-copybutton
 mlx
@@ -18,6 +18,7 @@ release = version
 # -- General configuration ---------------------------------------------------

 extensions = [
+    "sphinx_copybutton",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
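The two hunks above add sphinx-copybutton as a docs dependency and enable it as a Sphinx extension. It needs no further configuration; if doctest prompts should also be stripped from copied snippets, options along these lines can be added to conf.py (option names come from sphinx-copybutton's own documentation, not from this change):

```python
# Optional sphinx-copybutton settings (not part of this diff).
copybutton_prompt_text = r">>> |\.\.\. "  # strip ">>> " and "... " prompts when copying
copybutton_prompt_is_regexp = True
```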
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
     optimizer.update(model, grads)

     # Save the state
-    state = tree_flatten(optimizer.state)
-    mx.save_safetensors("optimizer.safetensors", dict(state))
+    state = tree_flatten(optimizer.state, destination={})
+    mx.save_safetensors("optimizer.safetensors", state)

     # Later on, for example when loading from a checkpoint,
     # recreate the optimizer and load the state
     optimizer = optim.Adam(learning_rate=1e-2)

-    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+    state = tree_unflatten(mx.load("optimizer.safetensors"))
     optimizer.state = state

 Note, not every optimizer configuration parameter is saved in the state. For
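Condensed, the updated save/load flow from this hunk looks as follows; it assumes a model and grads exist as in the surrounding documentation, and relies on the `destination` argument and dict support introduced by this change:

```python
import mlx.core as mx
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten

optimizer = optim.Adam(learning_rate=1e-2)
# ... training steps: optimizer.update(model, grads) ...

# Save: flatten the optimizer state directly into a dict and write it out.
mx.save_safetensors(
    "optimizer.safetensors", tree_flatten(optimizer.state, destination={}))

# Load: mx.load returns a dict, which tree_unflatten now accepts directly.
optimizer = optim.Adam(learning_rate=1e-2)
optimizer.state = tree_unflatten(mx.load("optimizer.safetensors"))
```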
@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
        model.update(tree_unflatten(list(params.items())))
        return model(x)

-   params = dict(tree_flatten(model.parameters()))
+   params = tree_flatten(model.parameters(), destination={})
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

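Both documentation hunks reflect the same API shift: `tree_flatten` used to return a list of `(key, value)` pairs that callers wrapped in `dict(...)`, and with `destination={}` it now fills and returns the dict directly. A small sketch of the difference (the `nn.Linear` model is illustrative):

```python
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Linear(4, 4)

# Old pattern: list of (key, value) pairs, converted by the caller.
old_params = dict(tree_flatten(model.parameters()))

# New pattern from this change: flatten straight into the destination dict.
new_params = tree_flatten(model.parameters(), destination={})

assert old_params.keys() == new_params.keys()
```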
@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
    switch (in.dtype()) {
      case bool_:
      case uint8:
+        reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
+        break;
+      case uint16:
+        reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
+        break;
+      case uint32:
+        reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
+        break;
+      case uint64:
+        reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
+        break;
      case int8:
        reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
        break;
      case int16:
-      case uint16:
        reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
        break;
      case int32:
-      case uint32:
        reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
        break;
      case int64:
-      case uint64:
        reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
        break;
      case float16:
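The old dispatch above folded each unsigned integer type into the signed dispatch of the same width, so unsigned inputs were effectively read through a signed element type. The new cases give every unsigned width its own accumulator. A standalone illustration of the failure mode this avoids (not MLX code):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

int main() {
  // A perfectly valid uint16 element...
  uint16_t raw = 40000;
  // ...reinterpreted through int16_t, as the old uint16 -> int16_t routing
  // would have done, becomes negative and corrupts sums and products.
  int16_t reinterpreted;
  std::memcpy(&reinterpreted, &raw, sizeof(raw));
  std::cout << raw << " read as int16_t -> " << reinterpreted << "\n";  // -25536
}
```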
@@ -24,6 +24,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -53,10 +54,10 @@ target_sources(

 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
 else()
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
 endif()

 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -1,208 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/gemms/cublas_gemm.h"
-#include "mlx/backend/cuda/kernel_utils.cuh"
-
-#include <cooperative_groups.h>
-
-namespace mlx::core::cu {
-
-namespace cg = cooperative_groups;
-
-__global__ void set_mm_device_pointers(
-    int8_t** pointers,
-    int8_t* a_start,
-    int8_t* b_start,
-    int8_t* out_start,
-    int item_size,
-    const __grid_constant__ Shape batch_shape,
-    const __grid_constant__ Strides a_batch_strides,
-    const __grid_constant__ Strides b_batch_strides,
-    int64_t batch_stride,
-    int batch_ndim,
-    int batch_count) {
-  auto index = cg::this_grid().thread_rank();
-  if (index >= batch_count) {
-    return;
-  }
-  auto [a_offset, b_offset] = elem_to_loc(
-      index,
-      batch_shape.data(),
-      a_batch_strides.data(),
-      b_batch_strides.data(),
-      batch_ndim);
-  pointers[index] = a_start + item_size * a_offset;
-  pointers[index + batch_count] = b_start + item_size * b_offset;
-  pointers[index + 2 * batch_count] =
-      out_start + item_size * index * batch_stride;
-}
-
-__global__ void set_addmm_device_pointers(
-    int8_t** pointers,
-    int8_t* a_start,
-    int8_t* b_start,
-    int8_t* c_start,
-    int8_t* out_start,
-    int item_size,
-    const __grid_constant__ Shape batch_shape,
-    const __grid_constant__ Strides a_batch_strides,
-    const __grid_constant__ Strides b_batch_strides,
-    const __grid_constant__ Strides c_batch_strides,
-    int64_t batch_stride,
-    int batch_ndim,
-    int batch_count) {
-  auto index = cg::this_grid().thread_rank();
-  if (index >= batch_count) {
-    return;
-  }
-  auto [a_offset, b_offset, c_offset] = elem_to_loc(
-      index,
-      batch_shape.data(),
-      a_batch_strides.data(),
-      b_batch_strides.data(),
-      c_batch_strides.data(),
-      batch_ndim);
-  pointers[index] = a_start + item_size * a_offset;
-  pointers[index + batch_count] = b_start + item_size * b_offset;
-  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
-  pointers[index + 3 * batch_count] =
-      out_start + item_size * index * batch_stride;
-}
-
-void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
-  auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
-  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-      desc,
-      CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
-      &batch_mode,
-      sizeof(batch_mode)));
-  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-      desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
-}
-
-void Matmul::run_batched(
-    cu::CommandEncoder& encoder,
-    array& out,
-    const array& a,
-    const array& b,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides) {
-  auto batch_count = out.size() / (M_ * N_);
-  set_pointer_mode(a_desc_, batch_count);
-  set_pointer_mode(b_desc_, batch_count);
-  set_pointer_mode(out_desc_, batch_count);
-
-  // Launch kernel to set device offsets
-  auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
-      {static_cast<int>(batch_count * 3)},
-      uint64);
-
-  encoder.add_temporary(pointers);
-  int block_size = 512;
-  encoder.set_output_array(pointers);
-
-  encoder.add_kernel_node(
-      cu::set_mm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
-
-  // Run matmul
-  encoder.set_input_array(pointers);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-
-  auto a_pointers = pointers.data<int8_t*>();
-  auto b_pointers = a_pointers + batch_count;
-  auto out_pointers = b_pointers + batch_count;
-  run_impl(
-      encoder,
-      reinterpret_cast<void*>(out_pointers),
-      reinterpret_cast<void*>(a_pointers),
-      reinterpret_cast<void*>(b_pointers),
-      nullptr);
-}
-
-void Matmul::run_batched(
-    cu::CommandEncoder& encoder,
-    array& out,
-    const array& a,
-    const array& b,
-    const array& c,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides,
-    const mlx::core::Strides& c_batch_strides,
-    float alpha,
-    float beta) {
-  auto batch_count = out.size() / (M_ * N_);
-  set_pointer_mode(a_desc_, batch_count);
-  set_pointer_mode(b_desc_, batch_count);
-  set_pointer_mode(c_desc_, batch_count);
-  set_pointer_mode(out_desc_, batch_count);
-
-  // Launch kernel to set device offsets
-  auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
-      {static_cast<int>(batch_count * 4)},
-      uint64);
-
-  encoder.add_temporary(pointers);
-  int block_size = 512;
-  encoder.set_output_array(pointers);
-  encoder.add_kernel_node(
-      cu::set_addmm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      c.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      const_param(c_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
-
-  // Run matmul
-  encoder.set_input_array(pointers);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_input_array(c);
-  encoder.set_output_array(out);
-
-  auto a_pointers = pointers.data<int8_t*>();
-  auto b_pointers = a_pointers + batch_count;
-  auto c_pointers = b_pointers + batch_count;
-  auto out_pointers = c_pointers + batch_count;
-  run_impl(
-      encoder,
-      reinterpret_cast<void*>(out_pointers),
-      reinterpret_cast<void*>(a_pointers),
-      reinterpret_cast<void*>(b_pointers),
-      reinterpret_cast<void*>(c_pointers),
-      alpha,
-      beta);
-}
-
-} // namespace mlx::core::cu
@@ -7,10 +7,12 @@

 #include <fmt/format.h>

-namespace mlx::core::cu {
+namespace mlx::core {

+namespace {
+
 struct CublasPreference {
-  CublasPreference(Device& device) {
+  CublasPreference(cu::Device& device) {
    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
    // for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -33,7 +35,7 @@ struct CublasPreference {
  cublasLtMatmulPreference_t pref_{nullptr};
 };

-cublasLtMatmulPreference_t cublas_preference(Device& device) {
+cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
  static CublasPreference pref(device);
  return pref.pref_;
 }
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
      return CUBLAS_COMPUTE_64F;
    default:
      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
  }
 }

@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
      return CUDA_C_32F;
    default:
      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
  }
 }

@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
  return desc;
 }

-Matmul::Matmul(
-    Device& device,
+} // namespace
+
+CublasGemm::CublasGemm(
+    cu::Device& device,
    Dtype dtype,
    bool a_transposed,
    uint64_t a_rows,
@@ -155,8 +159,8 @@ Matmul::Matmul(
      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
 }

-Matmul::Matmul(
-    Device& device,
+CublasGemm::CublasGemm(
+    cu::Device& device,
    Dtype dtype,
    bool a_transposed,
    uint64_t a_rows,
@@ -171,7 +175,7 @@ Matmul::Matmul(
    int64_t a_batch_stride,
    int64_t b_batch_stride,
    int64_t c_batch_stride)
-    : Matmul(
+    : CublasGemm(
          device,
          dtype,
          a_transposed,
@@ -190,7 +194,7 @@ Matmul::Matmul(
      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
 }

-Matmul::~Matmul() {
+CublasGemm::~CublasGemm() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
 }

-void Matmul::run_impl(
+void CublasGemm::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
+  int batch_count = out.size() / (M_ * N_);
+  if (batch_count / batch_shape.back() > 1) {
+    run_batched(
+        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+    return;
+  }
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+
+  execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
+}
+
+void CublasGemm::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  int batch_count = out.size() / (M_ * N_);
+  if (batch_count / batch_shape.back() > 1) {
+    run_batched(
+        encoder,
+        out,
+        a,
+        b,
+        c,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        c_batch_strides,
+        alpha,
+        beta);
+    return;
+  }
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  execute(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      c.data<void>(),
+      alpha,
+      beta);
+}
+
+void CublasGemm::execute(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
@@ -256,29 +326,4 @@ void Matmul::run_impl(
      encoder.stream()));
 }

-void Matmul::run(
-    cu::CommandEncoder& encoder,
-    array& out,
-    const array& a,
-    const array& b,
-    const std::optional<array>& c /* = std::nullopt */,
-    float alpha /* = 1 */,
-    float beta /* = 0 */) {
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  if (c) {
-    encoder.set_input_array(*c);
-  }
-  encoder.set_output_array(out);
-
-  run_impl(
-      encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
-      c ? c->data<void>() : nullptr,
-      alpha,
-      beta);
-}
-
-} // namespace mlx::core::cu
+} // namespace mlx::core
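The new public `run` overloads fold the old call-site branching into the class: the last batch dimension is already covered by the strided-batched layouts set up in the constructor (which receives `batch_shape.back()` and the trailing batch strides), so `run` only falls back to `run_batched` when extra leading batch dimensions remain. A small sketch of that arithmetic, with illustrative shapes:

```cpp
#include <iostream>

// batch_count = number of output matrices; last_batch_dim = batch_shape.back().
bool needs_pointer_batching(long batch_count, long last_batch_dim) {
  return batch_count / last_batch_dim > 1;
}

int main() {
  std::cout << needs_pointer_batching(8, 8) << "\n";   // batch shape (8): strided path
  std::cout << needs_pointer_batching(24, 8) << "\n";  // batch shape (3, 8): run_batched
}
```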
@@ -5,13 +5,13 @@
 #include "mlx/backend/cuda/device.h"

 #include <cublasLt.h>
-#include <optional>

-namespace mlx::core::cu {
-class Matmul {
+namespace mlx::core {
+
+class CublasGemm {
  public:
-  Matmul(
-      Device& device,
+  CublasGemm(
+      cu::Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
      int64_t a_batch_stride,
      int64_t b_batch_stride);

-  Matmul(
-      Device& device,
+  CublasGemm(
+      cu::Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
@@ -42,25 +42,39 @@ class Matmul {
      int64_t b_batch_stride,
      int64_t c_batch_stride);

-  ~Matmul();
+  ~CublasGemm();

  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
-      const std::optional<array>& c = std::nullopt,
-      float alpha = 1,
-      float beta = 0);
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides);
+
+  void run(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const array& c,
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides,
+      const Strides& c_batch_strides,
+      float alpha,
+      float beta);
+
+ private:
  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
-      const mlx::core::Shape& batch_shape,
-      const mlx::core::Strides& a_batch_strides,
-      const mlx::core::Strides& b_batch_strides);
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides);

  void run_batched(
      cu::CommandEncoder& encoder,
@@ -68,15 +82,14 @@ class Matmul {
      const array& a,
      const array& b,
      const array& c,
-      const mlx::core::Shape& batch_shape,
-      const mlx::core::Strides& a_batch_strides,
-      const mlx::core::Strides& b_batch_strides,
-      const mlx::core::Strides& c_batch_strides,
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides,
+      const Strides& c_batch_strides,
      float alpha,
      float beta);

- private:
-  void run_impl(
+  void execute(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
@@ -97,4 +110,4 @@ class Matmul {
  cublasLtMatmulHeuristicResult_t heuristic_;
 };

-} // namespace mlx::core::cu
+} // namespace mlx::core
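The header keeps the same ownership model under the new name: the object creates its cuBLASLt layout and matmul descriptors in the constructors and releases them unconditionally in `~CublasGemm()`. A minimal, generic sketch of that pattern with stand-in handle types (not the real cuBLASLt API):

```cpp
#include <iostream>

struct FakeLayout {};  // stand-in for cublasLtMatrixLayout_t
FakeLayout* create_layout() { return new FakeLayout(); }
void destroy_layout(FakeLayout* l) { delete l; }

class Gemm {
 public:
  Gemm() : a_desc_(create_layout()), b_desc_(create_layout()) {}
  ~Gemm() {
    destroy_layout(a_desc_);
    destroy_layout(b_desc_);
  }
  Gemm(const Gemm&) = delete;
  Gemm& operator=(const Gemm&) = delete;

 private:
  FakeLayout* a_desc_;
  FakeLayout* b_desc_;
};

int main() { Gemm g; std::cout << "descriptors live for the scope of g\n"; }
```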
@@ -4,16 +4,16 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"

-namespace mlx::core::cu {
+namespace mlx::core {

-void Matmul::run_batched(
+void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides) {
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void Matmul::run_batched(
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
-    run_impl(
+    execute(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void Matmul::run_batched(
  }
 }

-void Matmul::run_batched(
+void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& c,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides,
-    const mlx::core::Strides& c_batch_strides,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
    float alpha,
    float beta) {
  encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void Matmul::run_batched(
  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
-    run_impl(
+    execute(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void Matmul::run_batched(
  }
 }

-} // namespace mlx::core::cu
+} // namespace mlx::core
mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu (new file, 327 lines)
@@ -0,0 +1,327 @@
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
__global__ void set_mm_device_pointers_nd(
|
||||||
|
int8_t** pointers,
|
||||||
|
int8_t* a_start,
|
||||||
|
int8_t* b_start,
|
||||||
|
int8_t* out_start,
|
||||||
|
int item_size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||||
|
int64_t batch_stride,
|
||||||
|
int batch_count) {
|
||||||
|
auto index = cg::this_grid().thread_rank();
|
||||||
|
if (index >= batch_count) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
|
||||||
|
index,
|
||||||
|
batch_shape.data(),
|
||||||
|
a_batch_strides.data(),
|
||||||
|
b_batch_strides.data());
|
||||||
|
pointers[index] = a_start + item_size * a_offset;
|
||||||
|
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||||
|
pointers[index + 2 * batch_count] =
|
||||||
|
out_start + item_size * index * batch_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void set_mm_device_pointers_g(
|
||||||
|
int8_t** pointers,
|
||||||
|
int8_t* a_start,
|
||||||
|
int8_t* b_start,
|
||||||
|
int8_t* out_start,
|
||||||
|
int item_size,
|
||||||
|
const __grid_constant__ Shape batch_shape,
|
||||||
|
const __grid_constant__ Strides a_batch_strides,
|
||||||
|
const __grid_constant__ Strides b_batch_strides,
|
||||||
|
int64_t batch_stride,
|
||||||
|
int batch_ndim,
|
||||||
|
int batch_count) {
|
||||||
|
auto index = cg::this_grid().thread_rank();
|
||||||
|
if (index >= batch_count) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto [a_offset, b_offset] = elem_to_loc(
|
||||||
|
index,
|
||||||
|
batch_shape.data(),
|
||||||
|
a_batch_strides.data(),
|
||||||
|
b_batch_strides.data(),
|
||||||
|
batch_ndim);
|
||||||
|
pointers[index] = a_start + item_size * a_offset;
|
||||||
|
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||||
|
pointers[index + 2 * batch_count] =
|
||||||
|
out_start + item_size * index * batch_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int NDIM>
|
||||||
|
__global__ void set_addmm_device_pointers_nd(
|
||||||
|
int8_t** pointers,
|
||||||
|
int8_t* a_start,
|
||||||
|
int8_t* b_start,
|
||||||
|
int8_t* c_start,
|
||||||
|
int8_t* out_start,
|
||||||
|
int item_size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
|
||||||
|
int64_t batch_stride,
|
||||||
|
int batch_count) {
|
||||||
|
auto index = cg::this_grid().thread_rank();
|
||||||
|
if (index >= batch_count) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
|
||||||
|
index,
|
||||||
|
batch_shape.data(),
|
||||||
|
a_batch_strides.data(),
|
||||||
|
b_batch_strides.data(),
|
||||||
|
c_batch_strides.data());
|
||||||
|
pointers[index] = a_start + item_size * a_offset;
|
||||||
|
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||||
|
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||||
|
pointers[index + 3 * batch_count] =
|
||||||
|
out_start + item_size * index * batch_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void set_addmm_device_pointers_g(
|
||||||
|
int8_t** pointers,
|
||||||
|
int8_t* a_start,
|
||||||
|
int8_t* b_start,
|
||||||
|
int8_t* c_start,
|
||||||
|
int8_t* out_start,
|
||||||
|
int item_size,
|
||||||
|
const __grid_constant__ Shape batch_shape,
|
||||||
|
const __grid_constant__ Strides a_batch_strides,
|
||||||
|
const __grid_constant__ Strides b_batch_strides,
|
||||||
|
const __grid_constant__ Strides c_batch_strides,
|
||||||
|
int64_t batch_stride,
|
||||||
|
int batch_ndim,
|
||||||
|
int batch_count) {
|
||||||
|
auto index = cg::this_grid().thread_rank();
|
||||||
|
if (index >= batch_count) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto [a_offset, b_offset, c_offset] = elem_to_loc(
|
||||||
|
index,
|
||||||
|
batch_shape.data(),
|
||||||
|
a_batch_strides.data(),
|
||||||
|
b_batch_strides.data(),
|
||||||
|
c_batch_strides.data(),
|
||||||
|
batch_ndim);
|
||||||
|
pointers[index] = a_start + item_size * a_offset;
|
||||||
|
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||||
|
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||||
|
pointers[index + 3 * batch_count] =
|
||||||
|
out_start + item_size * index * batch_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||||
|
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
|
desc,
|
||||||
|
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
|
||||||
|
&batch_mode,
|
||||||
|
sizeof(batch_mode)));
|
||||||
|
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||||
|
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void CublasGemm::run_batched(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides) {
|
||||||
|
int batch_count = out.size() / (M_ * N_);
|
||||||
|
set_pointer_mode(a_desc_, batch_count);
|
||||||
|
set_pointer_mode(b_desc_, batch_count);
|
||||||
|
set_pointer_mode(out_desc_, batch_count);
|
||||||
|
|
||||||
|
// Launch kernel to set device offsets
|
||||||
|
auto pointers = array(
|
||||||
|
allocator::malloc(batch_count * sizeof(void*) * 3),
|
||||||
|
{batch_count * 3},
|
||||||
|
uint64);
|
||||||
|
|
||||||
|
encoder.add_temporary(pointers);
|
||||||
|
encoder.set_output_array(pointers);
|
||||||
|
|
||||||
|
int block_dims = std::min(batch_count, 256);
|
||||||
|
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||||
|
int64_t batch_stride = M_ * N_;
|
||||||
|
int item_size = out.itemsize();
|
||||||
|
|
||||||
|
int ndim = batch_shape.size();
|
||||||
|
if (ndim <= 3) {
|
||||||
|
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::set_mm_device_pointers_nd<ndim_constant()>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
pointers.data<int8_t*>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
out.data<int8_t>(),
|
||||||
|
item_size,
|
||||||
|
const_param<ndim_constant()>(batch_shape),
|
||||||
|
const_param<ndim_constant()>(a_batch_strides),
|
||||||
|
const_param<ndim_constant()>(b_batch_strides),
|
||||||
|
batch_stride,
|
||||||
|
batch_count);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::set_mm_device_pointers_g,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
pointers.data<int8_t*>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
out.data<int8_t>(),
|
||||||
|
item_size,
|
||||||
|
const_param(batch_shape),
|
||||||
|
const_param(a_batch_strides),
|
||||||
|
const_param(b_batch_strides),
|
||||||
|
batch_stride,
|
||||||
|
ndim,
|
||||||
|
batch_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run matmul
|
||||||
|
encoder.set_input_array(pointers);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto a_pointers = pointers.data<int8_t*>();
|
||||||
|
auto b_pointers = a_pointers + batch_count;
|
||||||
|
auto out_pointers = b_pointers + batch_count;
|
||||||
|
execute(
|
||||||
|
encoder,
|
||||||
|
reinterpret_cast<void*>(out_pointers),
|
||||||
|
reinterpret_cast<void*>(a_pointers),
|
||||||
|
reinterpret_cast<void*>(b_pointers),
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CublasGemm::run_batched(
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
array& out,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const array& c,
|
||||||
|
const Shape& batch_shape,
|
||||||
|
const Strides& a_batch_strides,
|
||||||
|
const Strides& b_batch_strides,
|
||||||
|
const Strides& c_batch_strides,
|
||||||
|
float alpha,
|
||||||
|
float beta) {
|
||||||
|
int batch_count = out.size() / (M_ * N_);
|
||||||
|
set_pointer_mode(a_desc_, batch_count);
|
||||||
|
set_pointer_mode(b_desc_, batch_count);
|
||||||
|
set_pointer_mode(c_desc_, batch_count);
|
||||||
|
set_pointer_mode(out_desc_, batch_count);
|
||||||
|
|
||||||
|
// Launch kernel to set device offsets
|
||||||
|
auto pointers = array(
|
||||||
|
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||||
|
{batch_count * 4},
|
||||||
|
uint64);
|
||||||
|
|
||||||
|
encoder.add_temporary(pointers);
|
||||||
|
encoder.set_output_array(pointers);
|
||||||
|
|
||||||
|
int block_dims = std::min(batch_count, 256);
|
||||||
|
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||||
|
int64_t batch_stride = M_ * N_;
|
||||||
|
int item_size = out.itemsize();
|
||||||
|
|
||||||
|
int ndim = batch_shape.size();
|
||||||
|
if (ndim <= 3) {
|
||||||
|
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::set_addmm_device_pointers_nd<ndim_constant()>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
pointers.data<int8_t*>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
out.data<int8_t>(),
|
||||||
|
item_size,
|
||||||
|
const_param<ndim_constant()>(batch_shape),
|
||||||
|
const_param<ndim_constant()>(a_batch_strides),
|
||||||
|
const_param<ndim_constant()>(b_batch_strides),
|
||||||
|
const_param<ndim_constant()>(c_batch_strides),
|
||||||
|
batch_stride,
|
||||||
|
batch_count);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::set_addmm_device_pointers_g,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
pointers.data<int8_t*>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
out.data<int8_t>(),
|
||||||
|
item_size,
|
||||||
|
const_param(batch_shape),
|
||||||
|
const_param(a_batch_strides),
|
||||||
|
const_param(b_batch_strides),
|
||||||
|
const_param(c_batch_strides),
|
||||||
|
batch_stride,
|
||||||
|
ndim,
|
||||||
|
batch_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run matmul
|
||||||
|
encoder.set_input_array(pointers);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto a_pointers = pointers.data<int8_t*>();
|
||||||
|
auto b_pointers = a_pointers + batch_count;
|
||||||
|
auto c_pointers = b_pointers + batch_count;
|
||||||
|
auto out_pointers = c_pointers + batch_count;
|
||||||
|
execute(
|
||||||
|
encoder,
|
||||||
|
reinterpret_cast<void*>(out_pointers),
|
||||||
|
reinterpret_cast<void*>(a_pointers),
|
||||||
|
reinterpret_cast<void*>(b_pointers),
|
||||||
|
reinterpret_cast<void*>(c_pointers),
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
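The replacement file above specializes the pointer-setup kernels on the batch rank: for ranks 1-3 a compile-time `NDIM` variant is used (fixed-size shape/stride arrays), with a general variant taking `ndim` at runtime otherwise. A condensed sketch of that dispatch pattern; `dispatch_1_2_3` here stands in for the MLX helper of the same name:

```cpp
#include <iostream>
#include <utility>

template <typename F>
void dispatch_1_2_3(int ndim, F&& f) {
  switch (ndim) {
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 3: f(std::integral_constant<int, 3>{}); break;
  }
}

template <int NDIM>
void set_pointers_nd() { std::cout << "specialized kernel, NDIM=" << NDIM << "\n"; }
void set_pointers_general(int ndim) { std::cout << "general kernel, ndim=" << ndim << "\n"; }

void launch(int ndim) {
  if (ndim <= 3) {
    dispatch_1_2_3(ndim, [&](auto nd) { set_pointers_nd<nd()>(); });
  } else {
    set_pointers_general(ndim);
  }
}

int main() {
  launch(2);  // compile-time specialization
  launch(5);  // runtime fallback
}
```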
mlx/backend/cuda/gemms/steel_gemm.cu (new file, 301 lines)
@@ -0,0 +1,301 @@
#include "mlx/backend/common/matmul.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/gemms/steel_gemm.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/steel/gemm.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/mma.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
struct GemmParams {
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int lda;
|
||||||
|
int ldb;
|
||||||
|
int ldd;
|
||||||
|
|
||||||
|
int NblockM;
|
||||||
|
int NblockN;
|
||||||
|
int NblockK;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int SL,
|
||||||
|
int Nstages>
|
||||||
|
__global__ void kernel_steel_gemm(
|
||||||
|
const T* a,
|
||||||
|
const T* b,
|
||||||
|
T* d,
|
||||||
|
__grid_constant__ const GemmParams params) {
|
||||||
|
const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
|
||||||
|
const int bN_idx = blockIdx.x >> SL;
|
||||||
|
|
||||||
|
if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int d_row = bM_idx * BM;
|
||||||
|
const int d_col = bN_idx * BN;
|
||||||
|
const size_t d_row_long = size_t(d_row);
|
||||||
|
const size_t d_col_long = size_t(d_col);
|
||||||
|
|
||||||
|
a += transpose_a ? d_row_long : d_row_long * params.K;
|
||||||
|
b += transpose_b ? d_col_long * params.K : d_col_long;
|
||||||
|
d += d_row_long * params.ldd + d_col_long;
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
auto warp = cg::tiled_partition<32>(block);
|
||||||
|
|
||||||
|
const int lane_idx = warp.thread_rank();
|
||||||
|
const int warp_idx = warp.meta_group_rank();
|
||||||
|
|
||||||
|
const int wm = warp_idx / WN;
|
||||||
|
const int wn = warp_idx % WN;
|
||||||
|
|
||||||
|
constexpr int SM = BM / WM;
|
||||||
|
constexpr int SN = BN / WN;
|
||||||
|
constexpr int SK = BK;
|
||||||
|
constexpr int TK = SK / 16;
|
||||||
|
|
||||||
|
constexpr int NUM_WARPS = WM * WN;
|
||||||
|
|
||||||
|
// Allocate shared memory
|
||||||
|
extern __shared__ char shmem[];
|
||||||
|
SharedTile<T, BM, BK>(&as)[Nstages] =
|
||||||
|
*(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
|
||||||
|
SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
|
||||||
|
&shmem[sizeof(T) * Nstages * BM * BK]);
|
||||||
|
|
||||||
|
// Allocate registers for the MMA
|
||||||
|
RegisterTile<float, SM, SN> C;
|
||||||
|
RegisterTile<T, SM, 16> A[TK];
|
||||||
|
RegisterTile<T, SN, 16> B[TK];
|
||||||
|
|
||||||
|
// Zero the accumulators
|
||||||
|
C.fill(0);
|
||||||
|
|
||||||
|
// Start gmem -> smem copies
|
||||||
|
int k_block_read = 0;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int bk = 0; bk < (Nstages - 1); bk++) {
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
as[bk], as[bk].base_addr(), a + k_block_read, params.K);
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
|
||||||
|
k_block_read += BK;
|
||||||
|
cp_async_commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
int smem_pipe_read = 0;
|
||||||
|
int smem_pipe_write = Nstages - 1;
|
||||||
|
|
||||||
|
// Wait till only 1 remains laoding
|
||||||
|
cp_async_wait<1>();
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
const int offset_m = wm * SM;
|
||||||
|
const int offset_n = wn * SN;
|
||||||
|
|
||||||
|
// Start smem -> register copy
|
||||||
|
A[0].load(
|
||||||
|
as[smem_pipe_read],
|
||||||
|
as[smem_pipe_read].base_addr(),
|
||||||
|
offset_m + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
B[0].load(
|
||||||
|
bs[smem_pipe_read],
|
||||||
|
bs[smem_pipe_read].base_addr(),
|
||||||
|
offset_n + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
for (int kb = 0; kb < params.NblockK; kb++) {
|
||||||
|
// Prepare next registers
|
||||||
|
{
|
||||||
|
A[1].load(
|
||||||
|
as[smem_pipe_read],
|
||||||
|
as[smem_pipe_read].base_addr(),
|
||||||
|
offset_m + lane_idx % 16,
|
||||||
|
16 + lane_idx / 16 * 8);
|
||||||
|
B[1].load(
|
||||||
|
bs[smem_pipe_read],
|
||||||
|
bs[smem_pipe_read].base_addr(),
|
||||||
|
offset_n + lane_idx % 16,
|
||||||
|
16 + lane_idx / 16 * 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare next smem
|
||||||
|
if ((kb + Nstages - 1) < params.NblockK) {
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
as[smem_pipe_write],
|
||||||
|
as[smem_pipe_write].base_addr(),
|
||||||
|
a + k_block_read,
|
||||||
|
params.K);
|
||||||
|
load_async<NUM_WARPS>(
|
||||||
|
bs[smem_pipe_write],
|
||||||
|
bs[smem_pipe_write].base_addr(),
|
||||||
|
b + k_block_read,
|
||||||
|
params.K);
|
||||||
|
}
|
||||||
|
k_block_read += BK;
|
||||||
|
|
||||||
|
cp_async_commit();
|
||||||
|
|
||||||
|
smem_pipe_write = smem_pipe_read;
|
||||||
|
smem_pipe_read = smem_pipe_read + 1;
|
||||||
|
smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;
|
||||||
|
|
||||||
|
// Do current gemm
|
||||||
|
mma_t(C, A[0], B[0]);
|
||||||
|
|
||||||
|
// Do wait for next register
|
||||||
|
cp_async_wait<1>();
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
// Prepare next register (smem_pipe_read has moved to the next)
|
||||||
|
{
|
||||||
|
A[0].load(
|
||||||
|
as[smem_pipe_read],
|
||||||
|
as[smem_pipe_read].base_addr(),
|
||||||
|
offset_m + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
B[0].load(
|
||||||
|
bs[smem_pipe_read],
|
||||||
|
bs[smem_pipe_read].base_addr(),
|
||||||
|
offset_n + lane_idx % 16,
|
||||||
|
lane_idx / 16 * 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do current gemm
|
||||||
|
mma_t(C, A[1], B[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait and clear
|
||||||
|
cp_async_wait_all();
|
||||||
|
block.sync();
|
||||||
|
|
||||||
|
C.store_global(d, params.ldd, offset_m, offset_n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
void dispatch_steel_gemm(
|
||||||
|
const Stream& s,
|
||||||
|
cu::CommandEncoder& encoder,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& d,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
int ldd,
|
||||||
|
bool a_transposed,
|
||||||
|
bool b_transposed) {
|
||||||
|
using DataType = cuda_type_t<float16_t>;
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(d);
|
||||||
|
|
||||||
|
constexpr int BM = 128;
|
||||||
|
constexpr int BN = 128;
|
||||||
|
constexpr int BK = 32;
|
||||||
|
|
||||||
|
constexpr int WM = 2;
|
||||||
|
constexpr int WN = 2;
|
||||||
|
|
||||||
|
constexpr int SL = 0;
|
||||||
|
constexpr int Nstages = 3;
|
||||||
|
|
||||||
|
constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);
|
||||||
|
|
||||||
|
const int NblockM = (M + BM - 1) / BM;
|
||||||
|
const int NblockN = (N + BN - 1) / BN;
|
||||||
|
const int NblockK = (K + BK - 1) / BK;
|
||||||
|
|
||||||
|
cu::GemmParams params{
|
||||||
|
/* int M = */ M,
|
||||||
|
/* int N = */ N,
|
||||||
|
/* int K = */ K,
|
||||||
|
/* int lda = */ lda,
|
||||||
|
/* int ldb = */ ldb,
|
||||||
|
/* int ldd = */ ldd,
|
||||||
|
|
||||||
|
/* int NblockM = */ NblockM,
|
||||||
|
/* int NblockN = */ NblockN,
|
||||||
|
/* int NblockK = */ NblockK,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Prepare launch grid params
|
||||||
|
int tile = 1 << SL;
|
||||||
|
int tm = (NblockM + tile - 1) / tile;
|
||||||
|
int tn = NblockN * tile;
|
||||||
|
|
||||||
|
dim3 grid_dim(tn, tm, 1);
|
||||||
|
dim3 block_dim(32 * WM * WN, 1, 1);
|
||||||
|
|
||||||
|
dispatch_bool(a_transposed, [&](auto ta_) {
|
||||||
|
dispatch_bool(b_transposed, [&](auto tb_) {
|
||||||
|
constexpr bool ta = ta_.value;
|
||||||
|
constexpr bool tb = tb_.value;
|
||||||
|
|
||||||
|
auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||||
|
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid_dim,
|
||||||
|
block_dim,
|
||||||
|
smem_bytes,
|
||||||
|
a.data<DataType>(),
|
||||||
|
b.data<DataType>(),
|
||||||
|
d.data<DataType>(),
|
||||||
|
N,
|
||||||
|
K);
|
||||||
|
|
||||||
|
// auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
|
||||||
|
// tb, SL, Nstages>;
|
||||||
|
|
||||||
|
// cudaFuncSetAttribute(kernel,
|
||||||
|
// cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
|
||||||
|
|
||||||
|
// encoder.add_kernel_node(
|
||||||
|
// kernel,
|
||||||
|
// grid_dim,
|
||||||
|
// block_dim,
|
||||||
|
// smem_bytes,
|
||||||
|
// a.data<DataType>(),
|
||||||
|
// b.data<DataType>(),
|
||||||
|
// d.data<DataType>(),
|
||||||
|
// params);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
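The `kernel_steel_gemm` kernel above overlaps global-to-shared copies with tensor-core math using `Nstages` shared-memory tiles as a ring buffer: `Nstages - 1` async copies are issued up front, then each K-block iteration issues one new copy into the slot that was just consumed while computing on the oldest completed tile. A host-side sketch of the slot rotation only (numbers illustrative):

```cpp
#include <cstdio>

int main() {
  const int Nstages = 3;  // shared-memory tiles in flight
  const int NblockK = 6;  // K-block iterations
  int read = 0, write = Nstages - 1;
  for (int kb = 0; kb < NblockK; ++kb) {
    bool prefetch = (kb + Nstages - 1) < NblockK;  // stop issuing copies near the end
    std::printf("iter %d: compute on slot %d, %s slot %d\n",
                kb, read, prefetch ? "prefetch into" : "skip prefetch of", write);
    write = read;                                  // the slot just consumed is refilled next
    read = (read + 1 == Nstages) ? 0 : read + 1;
  }
}
```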
mlx/backend/cuda/gemms/steel_gemm.h (new file, 27 lines)
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "mlx/backend/common/matmul.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <numeric>
+
+namespace mlx::core {
+
+void dispatch_steel_gemm(
+    const Stream& s,
+    cu::CommandEncoder& encoder,
+    const array& a,
+    const array& b,
+    array& d,
+    int M,
+    int N,
+    int K,
+    int lda,
+    int ldb,
+    int ldd,
+    bool a_transposed,
+    bool b_transposed);
+
+} // namespace mlx::core
@@ -7,6 +7,8 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"

+#include "mlx/backend/cuda/gemms/steel_gemm.h"
+
 #include <nvtx3/nvtx3.hpp>
 #include <numeric>

@@ -95,9 +97,27 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    return;
  }

+  if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
+      b_transposed) {
+    return dispatch_steel_gemm(
+        /* const Stream& s = */ s,
+        /* cu::CommandEncoder& encoder = */ encoder,
+        /* const array& a = */ a,
+        /* const array& b = */ b,
+        /* array& d = */ out,
+        /* int M = */ M,
+        /* int N = */ N,
+        /* int K = */ K,
+        /* int lda = */ lda,
+        /* int ldb = */ ldb,
+        /* int ldd = */ N,
+        /* bool a_transposed = */ a_transposed,
+        /* bool b_transposed = */ b_transposed);
+  }
+
  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt
-  cu::Matmul matmul(
+  CublasGemm gemm(
      cu::device(s.device),
      a.dtype(),
      a_transposed,
@@ -111,14 +131,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      batch_shape.back(),
      a_batch_strides.back(),
      b_batch_strides.back());
-
-  if ((batch_count / batch_shape.back()) == 1) {
-    matmul.run(encoder, out, a, b);
-    return;
-  }
-
-  matmul.run_batched(
-      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
 }

 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -186,7 +199,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  /////////////////////////////////////////////////////////////////////////////
  // Invoke cublasLt

-  cu::Matmul matmul(
+  CublasGemm gemm(
      cu::device(s.device),
      a.dtype(),
      a_transposed,
@@ -202,12 +215,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      a_batch_strides.back(),
      b_batch_strides.back(),
      c_batch_strides.back());
-
-  if ((batch_count / batch_shape.back()) == 1) {
-    matmul.run(encoder, out, a, b, c, alpha_, beta_);
-    return;
-  }
-  matmul.run_batched(
+  gemm.run(
      encoder,
      out,
      a,
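Per the condition added in `Matmul::eval_gpu` above, the steel path currently only triggers for float16 inputs with no leading batch dimensions where A is not transposed and B is transposed in memory. A sketch of a Python-level call that should map onto that case (shapes are illustrative; whether `b.T` actually reaches the kernel as a transposed layout depends on how the graph is evaluated):

```python
import mlx.core as mx

a = mx.zeros((256, 128), dtype=mx.float16)
b = mx.zeros((512, 128), dtype=mx.float16)

# a @ b.T is the classic "NT" GEMM: A not transposed, B transposed.
out = a @ b.T
mx.eval(out)
```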
@@ -8,19 +8,13 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"

-// cudnn_frontend.h redefines this macro.
-#undef CHECK_CUDA_ERROR
-
-#include <cudnn_frontend.h>
-#include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>

 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>

-namespace fe = cudnn_frontend;
-
 namespace mlx::core {

 namespace cu {
@@ -645,294 +639,6 @@ void sdpa_vector_fallback(
   }
 }
 
-struct SDPACacheKey {
-  int device_id;
-  fe::DataType_t cudnn_type;
-
-  int B;
-  int H;
-  int D;
-
-  int qL;
-  int kL;
-
-  int gqa_factor;
-  float scale;
-
-  int64_t Q_strides[3];
-  int64_t K_strides[3];
-  int64_t V_strides[3];
-  int64_t O_strides[3];
-
-  bool generate_stats;
-  bool causal_mask;
-};
-
-auto& sdpa_cache() {
-  static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
-      cache(
-          /* capacity */ 128);
-  return cache;
-}
-
-#define Q_UID 1
-#define K_UID 2
-#define V_UID 3
-#define O_UID 4
-#define STATS_UID 5
-
-std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
-    cu::CommandEncoder& encoder,
-    const SDPACacheKey& cache_key) {
-  // Check if graph has already been fully built
-  if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
-    return it->second;
-  }
-
-  // Set up new graph
-  auto graph = std::make_shared<fe::graph::Graph>();
-
-  graph->set_io_data_type(cache_key.cudnn_type)
-      .set_intermediate_data_type(fe::DataType_t::FLOAT)
-      .set_compute_data_type(fe::DataType_t::FLOAT);
-
-  auto Q = graph->tensor(
-      fe::graph::Tensor_attributes()
-          .set_name("Q")
-          .set_uid(Q_UID)
-          .set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
-          .set_stride(
-              {cache_key.Q_strides[0],
-               cache_key.Q_strides[1],
-               cache_key.Q_strides[2],
-               1}));
-
-  int h_kv = cache_key.H / cache_key.gqa_factor;
-  auto K =
-      graph->tensor(fe::graph::Tensor_attributes()
-                        .set_name("K")
-                        .set_uid(K_UID)
-                        .set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
-                        .set_stride(
-                            {cache_key.K_strides[0],
-                             cache_key.K_strides[1],
-                             cache_key.V_strides[2],
-                             1}));
-
-  auto V =
-      graph->tensor(fe::graph::Tensor_attributes()
-                        .set_name("V")
-                        .set_uid(V_UID)
-                        .set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
-                        .set_stride(
-                            {cache_key.V_strides[0],
-                             cache_key.V_strides[1],
-                             cache_key.V_strides[2],
-                             1}));
-
-  auto sdpa_options = fe::graph::SDPA_attributes()
-                          .set_name("flash_attention")
-                          .set_is_inference(!cache_key.generate_stats)
-                          .set_attn_scale(cache_key.scale);
-
-  if (cache_key.causal_mask && cache_key.qL > 1) {
-    sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
-        .set_diagonal_band_right_bound(0);
-  }
-
-  auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
-
-  O->set_output(true)
-      .set_uid(O_UID)
-      .set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
-      .set_stride(
-          {cache_key.O_strides[0],
-           cache_key.O_strides[1],
-           cache_key.O_strides[2],
-           1});
-
-  if (cache_key.generate_stats) {
-    Stats->set_output(true)
-        .set_data_type(fe::DataType_t::FLOAT)
-        .set_uid(STATS_UID);
-  }
-
-  // Build and Validate cudnn graph
-
-  auto handle = encoder.device().cudnn_handle();
-
-  // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
-  if (cudnnGetVersion() < 90600) {
-    auto build_status = graph->build(handle, {fe::HeurMode_t::A});
-    if (!build_status.is_good()) {
-      throw std::runtime_error(
-          "Unable to build cudnn graph for attention."
-          " Failed with message: " +
-          build_status.get_message());
-    }
-
-  } else {
-    auto val_status = graph->validate();
-    auto op_status = graph->build_operation_graph(handle);
-
-    auto plan_stauts =
-        graph->create_execution_plans({cudnn_frontend::HeurMode_t::A});
-    if (!plan_stauts.is_good()) {
-      throw std::runtime_error(
-          "Unable to create exec plan for cudnn attention."
-          " Failed with message: " +
-          plan_stauts.get_message());
-    }
-
-    graph->select_behavior_notes(
-        {cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
-
-    auto support_status = graph->check_support(handle);
-    if (!support_status.is_good()) {
-      throw std::runtime_error(
-          "No cuda graph support for cudnn attention."
-          " Failed with message: " +
-          support_status.get_message());
-    }
-
-    auto build_status = graph->build_plans(handle);
-    if (!build_status.is_good()) {
-      throw std::runtime_error(
-          "Unable to build cudnn graph for attention."
-          " Failed with message: " +
-          build_status.get_message());
-    }
-  }
-
-  auto [it, _] = sdpa_cache().emplace(cache_key, graph);
-
-  return it->second;
-}
-
-inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
-  switch (dtype) {
-    case int8:
-      return fe::DataType_t::INT8;
-    case int32:
-      return fe::DataType_t::INT32;
-    case uint8:
-      return fe::DataType_t::UINT8;
-    case float16:
-      return fe::DataType_t::HALF;
-    case bfloat16:
-      return fe::DataType_t::BFLOAT16;
-    case float32:
-      return fe::DataType_t::FLOAT;
-    case float64:
-      return fe::DataType_t::DOUBLE;
-    default:
-      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
-  }
-}
-
-void sdpa_cudnn(
-    const Stream& s,
-    cu::CommandEncoder& encoder,
-    const array& q,
-    const array& k,
-    const array& v,
-    const float scale,
-    array& o,
-    bool do_causal_ = false) {
-  encoder.set_input_array(q);
-  encoder.set_input_array(k);
-  encoder.set_input_array(v);
-  encoder.set_output_array(o);
-
-  auto cudnn_type = dtype_to_cudnn_type(q.dtype());
-
-  int B = q.shape(0);
-  int H = q.shape(1);
-  int D = q.shape(3);
-  int gqa_factor = q.shape(1) / k.shape(1);
-
-  int qL = q.shape(2);
-  int kL = k.shape(2);
-
-  SDPACacheKey cache_key{
-      /* int device_id = */ encoder.device().cuda_device(),
-      /* fe::DataType_t cudnn_type = */ cudnn_type,
-
-      /* int B = */ B,
-      /* int H = */ H,
-      /* int D = */ D,
-
-      /* int qL = */ qL,
-      /* int kL = */ kL,
-
-      /* int gqa_factor = */ gqa_factor,
-      /* float scale = */ scale,
-
-      /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
-      /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
-      /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
-      /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)},
-
-      /* bool generate_stats = */ false,
-      /* bool causal_mask = */ do_causal_};
-
-  auto graph = get_sdpa_forward_graph(encoder, cache_key);
-
-  int64_t workspace_size = 0;
-  auto workspace_status = graph->get_workspace_size(workspace_size);
-  if (!workspace_status.is_good()) {
-    throw std::runtime_error("Unable to get workspace for cudnn attention.");
-  }
-
-  array workspace(
-      allocator::malloc(workspace_size), {int(workspace_size)}, uint8);
-  auto workspace_ptr = workspace.data<void>();
-
-  std::unordered_map<int64_t, void*> variant_pack = {
-      {Q_UID, const_cast<void*>(q.data<void>())},
-      {K_UID, const_cast<void*>(k.data<void>())},
-      {V_UID, const_cast<void*>(v.data<void>())},
-      {O_UID, o.data<void>()}};
-
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
-
-  // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
-  if (cudnnGetVersion() < 90600) {
-    auto capture = encoder.capture_context();
-    auto exec_status = graph->execute(handle, variant_pack, workspace_ptr);
-
-    if (!exec_status.is_good()) {
-      capture.discard = true;
-      throw std::runtime_error(
-          "Unable to execute cudnn attention."
-          " Failed with message: " +
-          exec_status.get_message());
-    }
-  } else {
-    cudaGraph_t cu_graph;
-    cudaGraphCreate(&cu_graph, 0);
-
-    std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-        &cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
-
-    auto cu_graph_status = graph->populate_cuda_graph(
-        handle, variant_pack, workspace_ptr, cu_graph);
-
-    if (!cu_graph_status.is_good()) {
-      throw std::runtime_error(
-          "Unable to add cuda graph for cudnn attention."
-          " Failed with message: " +
-          cu_graph_status.get_message());
-    }
-
-    encoder.add_graph_node(cu_graph);
-  }
-
-  encoder.add_temporary(workspace);
-}
-
 } // namespace
 
 namespace fast {
@@ -945,6 +651,9 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_arr_mask,
     bool do_causal,
     Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
   if (s.device == Device::cpu) {
     return true;
  }
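
With the `in_grad_tracing()` check above, the fused kernel path is skipped whenever the attention call is being traced for gradients. A minimal sketch of what that means from Python, assuming illustrative shapes and loss (not part of this patch):

    import mlx.core as mx

    q = mx.random.normal((1, 8, 1, 64))    # (B, H, L, D), illustrative
    k = mx.random.normal((1, 8, 128, 64))
    v = mx.random.normal((1, 8, 128, 64))

    def loss(q):
        o = mx.fast.scaled_dot_product_attention(q, k, v, scale=0.125)
        return o.sum()

    # Forward-only calls may use the fused kernels; under mx.grad the
    # primitive reports use_fallback() == true and the composable fallback runs.
    dq = mx.grad(loss)(q)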
@@ -960,15 +669,7 @@ bool ScaledDotProductAttention::use_fallback(
   const bool supported_vector_config =
       sdpa_supported_head_dim && query_sequence_length < 4;
 
-  auto& cu_device = cu::device(s.device);
-
-  const bool supported_matrix_config = query_sequence_length > 4 &&
-      cu_device.compute_capability_major() >= 8 &&
-      query_sequence_length == key_sequence_length &&
-      (q.dtype() == float16 || q.dtype() == bfloat16);
-
-  const bool supported_config =
-      (supported_matrix_config || supported_vector_config);
+  const bool supported_config = supported_vector_config;
 
   return has_arr_mask || !supported_config;
 }
@@ -1002,10 +703,6 @@ void ScaledDotProductAttention::eval_gpu(
     }
   };
 
-  auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(-1) == 1;
-  };
-
   // We are in vector mode ie single query
   if (q_pre.shape(2) < 4) {
     auto q_copy_unless = [](const array& arr) {
@@ -1059,7 +756,7 @@ void ScaledDotProductAttention::eval_gpu(
 
     array::Flags flags{
         /* bool contiguous = */ 1,
-        /* bool row_contiguous = */ 0,
+        /* bool row_contiguous = */ o.shape(2) == 1,
         /* bool col_contiguous = */ 0,
     };
 
@@ -1073,35 +770,9 @@ void ScaledDotProductAttention::eval_gpu(
       return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
     }
 
-  // Full attention mode
+  // Full attention mode should never reach here
   else {
-    const auto& q = copy_unless(is_matrix_contiguous, q_pre);
-    const auto& k = copy_unless(is_matrix_contiguous, k_pre);
-    const auto& v = copy_unless(is_matrix_contiguous, v_pre);
-
-    for (const auto& cp : copies) {
-      encoder.add_temporary(cp);
-    }
-
-    int64_t str_oD = 1;
-    int64_t str_oH = o.shape(3);
-    int64_t str_oL = o.shape(1) * str_oH;
-    int64_t str_oB = o.shape(2) * str_oL;
-    size_t data_size = o.shape(0) * str_oB;
-
-    array::Flags flags{
-        /* bool contiguous = */ 1,
-        /* bool row_contiguous = */ 0,
-        /* bool col_contiguous = */ 0,
-    };
-
-    o.set_data(
-        allocator::malloc(o.nbytes()),
-        data_size,
-        {str_oB, str_oH, str_oL, str_oD},
-        flags);
-
-    return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
+    throw std::runtime_error("Doesn't support matrix yet.");
   }
 }
 

@@ -143,85 +143,87 @@ struct Tile16x16 {
   }
 };
 
-/**
- * A simple container of multiple Tile16x16.
- *
- * Provides utility functions for loading and manipulating collections of basic
- * tiles.
- */
-template <typename T, int ROWS_, int COLS_>
-struct RegisterTile {
-  static constexpr int ROWS = ROWS_;
-  static constexpr int COLS = COLS_;
-  static constexpr int TILES_X = COLS / 16;
-  static constexpr int TILES_Y = ROWS / 16;
+// /**
+// * A simple container of multiple Tile16x16.
+// *
+// * Provides utility functions for loading and manipulating collections of
+// basic
+// * tiles.
+// */
+// template <typename T, int ROWS_, int COLS_>
+// struct RegisterTile {
+// static constexpr int ROWS = ROWS_;
+// static constexpr int COLS = COLS_;
+// static constexpr int TILES_X = COLS / 16;
+// static constexpr int TILES_Y = ROWS / 16;
 
-  Tile16x16<T> data[TILES_X * TILES_Y];
+// Tile16x16<T> data[TILES_X * TILES_Y];
 
-  __device__ inline void fill(T v) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].fill(v);
-      }
-    }
-  }
+// __device__ inline void fill(T v) {
+// MLX_UNROLL
+// for (int i = 0; i < TILES_Y; i++) {
+// MLX_UNROLL
+// for (int j = 0; j < TILES_X; j++) {
+// data[i * TILES_X + j].fill(v);
+// }
+// }
+// }
 
-  template <typename Tile>
-  __device__ __forceinline__ void
-  load(Tile& tile, uint32_t base_address, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].load(
-            tile.loc(base_address, row + i * 16, col + j * 16));
-      }
-    }
-  }
+// template <typename Tile>
+// __device__ __forceinline__ void
+// load(Tile& tile, uint32_t base_address, int row, int col) {
+// MLX_UNROLL
+// for (int i = 0; i < TILES_Y; i++) {
+// MLX_UNROLL
+// for (int j = 0; j < TILES_X; j++) {
+// data[i * TILES_X + j].load(
+// tile.loc(base_address, row + i * 16, col + j * 16));
+// }
+// }
+// }
 
-  template <typename Tile, typename F>
-  __device__ __forceinline__ void
-  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        f(data[i * TILES_X + j],
-          tile,
-          base_address,
-          row + i * 16,
-          col + j * 16);
-      }
-    }
-  }
+// template <typename Tile, typename F>
+// __device__ __forceinline__ void
+// load(Tile& tile, F f, uint32_t base_address, int row, int col) {
+// MLX_UNROLL
+// for (int i = 0; i < TILES_Y; i++) {
+// MLX_UNROLL
+// for (int j = 0; j < TILES_X; j++) {
+// f(data[i * TILES_X + j],
+// tile,
+// base_address,
+// row + i * 16,
+// col + j * 16);
+// }
+// }
+// }
 
-  template <typename U>
-  __device__ inline void store_global(U* x, int N, int row, int col) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].store_global(
-            x + (row + i * 16) * N + col + j * 16, N);
-      }
-    }
-  }
+// template <typename U>
+// __device__ inline void store_global(U* x, int N, int row, int col) {
+// MLX_UNROLL
+// for (int i = 0; i < TILES_Y; i++) {
+// MLX_UNROLL
+// for (int j = 0; j < TILES_X; j++) {
+// data[i * TILES_X + j].store_global(
+// x + (row + i * 16) * N + col + j * 16, N);
+// }
+// }
+// }
 
-  template <typename U>
-  __device__ inline void
-  store_global_safe(U* x, int N, int row, int col, int max_rows) {
-    MLX_UNROLL
-    for (int i = 0; i < TILES_Y; i++) {
-      MLX_UNROLL
-      for (int j = 0; j < TILES_X; j++) {
-        data[i * TILES_X + j].store_global_safe(
-            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
-      }
-    }
-  }
-};
+// template <typename U>
+// __device__ inline void
+// store_global_safe(U* x, int N, int row, int col, int max_rows) {
+// MLX_UNROLL
+// for (int i = 0; i < TILES_Y; i++) {
+// MLX_UNROLL
+// for (int j = 0; j < TILES_X; j++) {
+// data[i * TILES_X + j].store_global_safe(
+// x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
+// 16);
+// }
+// }
+// }
+// };
 
 /**
  * A simple container of multiple Tile16x16.

@@ -104,7 +104,7 @@ struct CommandEncoder {
   };
 
   // Outputs of all kernels in the encoder including temporaries
-  std::unordered_set<const void*> outputs() {
+  std::unordered_set<const void*>& outputs() {
    return all_outputs_;
   };
 

@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
 instantiate_and_or(or, Or)
 
 #define instantiate_sum_prod(name, op)                                \
+  instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op)     \
+  instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op)  \
+  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op)  \
+  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op)  \
  instantiate_reduce_functions(name, int8, int8_t, int32_t, op)       \
  instantiate_reduce_functions(name, int16, int16_t, int32_t, op)     \
  instantiate_reduce_functions(name, int32, int32_t, int32_t, op)     \

@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
     const std::string& op_name) {
   if (op_name == "sum" || op_name == "prod") {
     if (issubdtype(in.dtype(), integer)) {
-      switch (in.dtype().size()) {
-        case 1:
+      switch (in.dtype()) {
+        case uint8:
+          return {uint8, uint32};
+        case uint16:
+          return {uint16, uint32};
+        case uint32:
+          return {uint32, uint32};
+        case uint64:
+          return {uint64, uint64};
+        case int8:
           return {int8, int32};
-        case 2:
+        case int16:
           return {int16, int32};
-        case 4:
+        case int32:
           return {int32, int32};
-        case 8:
+        case int64:
           return {int64, int64};
+        default:
+          throw std::runtime_error("Unsupported integer type");
       }
     }
     if (in.dtype() == bool_) {
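
With the widened accumulation types above, unsigned sums and products no longer wrap at the input width. A hedged Python-level illustration, following the `{uint8, uint32}` style remapping in this hunk and the tests added further down:

    import mlx.core as mx

    x = mx.full((1000,), 255, dtype=mx.uint8)
    s = mx.sum(x)
    # Expected: 255000, carried in a 32-bit unsigned result rather than
    # wrapping around in uint8.
    print(s.item(), s.dtype)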

mlx/ops.cpp
@@ -2381,9 +2381,20 @@ array logsumexp(
     throw std::invalid_argument(
         "[logsumexp] Received non-empty axes for array with 0 dimensions.");
   }
+  bool reduce_last_dim =
+      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
+  if (reduce_last_dim) {
+    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
+    // is [1, 1, ..., N].
+    for (int i = axes.size() - 2; i >= 0; --i) {
+      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
+        reduce_last_dim = false;
+        break;
+      }
+    }
+  }
   bool is_complex = issubdtype(a.dtype(), complexfloating);
-  if (!is_complex && axes.size() == 1 &&
-      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  if (!is_complex && reduce_last_dim) {
     auto dtype = at_least_float(a.dtype());
     auto out_shape = a.shape();
     out_shape.back() = 1;
@@ -3403,10 +3414,20 @@ array softmax(
     throw std::invalid_argument(
         "[softmax] Received non-empty axes for array with 0 dimensions.");
   }
+  bool reduce_last_dim =
+      !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
+  if (reduce_last_dim) {
+    // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
+    // is [1, 1, ..., N].
+    for (int i = axes.size() - 2; i >= 0; --i) {
+      if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
+        reduce_last_dim = false;
+        break;
+      }
+    }
+  }
   bool is_complex = issubdtype(a.dtype(), complexfloating);
-  if (!is_complex && axes.size() == 1 &&
-      (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
+  if (!is_complex && reduce_last_dim) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
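
Both hunks relax the last-axis fast path in the same way: the reduction no longer has to be over a single axis, as long as the reduced axes end at the last dimension and every earlier reduced axis has size one. A hedged illustration of a call the new check admits (shapes are illustrative):

    import mlx.core as mx

    x = mx.random.uniform(shape=(1, 1, 4096))

    # Equivalent results; with the reduce_last_dim check the multi-axis call
    # can take the same specialized last-axis branch as axis=-1.
    y1 = mx.softmax(x, axis=-1)
    y2 = mx.softmax(x, axis=(0, 1, 2))
    assert mx.allclose(y1, y2)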

@@ -3,8 +3,8 @@
 #pragma once
 
 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 27
-#define MLX_VERSION_PATCH 1
+#define MLX_VERSION_MINOR 28
+#define MLX_VERSION_PATCH 0
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
 

@@ -2,6 +2,6 @@
 requires = [
     "setuptools>=80",
     "nanobind==2.4.0",
-    "cmake>=3.25",
+    "cmake>=3.25,<4.1",
 ]
 build-backend = "setuptools.build_meta"

@@ -178,7 +178,7 @@ class Module(dict):
 
         if strict:
             new_weights = dict(weights)
-            curr_weights = dict(tree_flatten(self.parameters()))
+            curr_weights = tree_flatten(self.parameters(), destination={})
             if extras := (new_weights.keys() - curr_weights.keys()):
                 num_extra = len(extras)
                 extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
           - ``.npz`` will use :func:`mx.savez`
           - ``.safetensors`` will use :func:`mx.save_safetensors`
         """
-        params_dict = dict(tree_flatten(self.parameters()))
+        params_dict = tree_flatten(self.parameters(), destination={})
 
         if file.endswith(".npz"):
             mx.savez(file, **params_dict)
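
`load_weights` and `save_weights` now flatten parameters straight into a dictionary rather than building a list of pairs first; the public behavior is unchanged. A hedged round-trip sketch (layer sizes and file name are illustrative):

    import mlx.nn as nn

    model = nn.Linear(4, 4)
    model.save_weights("weights.safetensors")   # keys like "weight", "bias"
    model.load_weights("weights.safetensors", strict=True)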

@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 from collections import defaultdict
 from itertools import zip_longest
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 
 def tree_map(
@@ -114,8 +114,11 @@ def tree_map_with_path(
 
 
 def tree_flatten(
-    tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
-) -> Any:
+    tree: Any,
+    prefix: str = "",
+    is_leaf: Optional[Callable] = None,
+    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
+) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
     """Flattens a Python tree to a list of key, value tuples.
 
     The keys are using the dot notation to define trees of arbitrary depth and
@@ -128,9 +131,12 @@ def tree_flatten(
         print(tree_flatten([[[0]]]))
         # [("0.0.0", 0)]
 
-        print(tree_flatten([[[0]]], ".hello"))
+        print(tree_flatten([[[0]]], prefix=".hello"))
         # [("hello.0.0.0", 0)]
 
+        tree_flatten({"a": {"b": 1}}, destination={})
+        {"a.b": 1}
+
     .. note::
        Dictionaries should have keys that are valid Python identifiers.
 
@@ -140,26 +146,50 @@ def tree_flatten(
            always discarded.
         is_leaf (callable): An optional callable that returns True if the
             passed object is considered a leaf or False otherwise.
+        destination (list or dict, optional): A list or dictionary to store the
+            flattened tree. If None an empty list will be used. Default: ``None``.
 
     Returns:
-        List[Tuple[str, Any]]: The flat representation of the Python tree.
+        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
+        the Python tree.
     """
-    flat_tree = []
-
-    if is_leaf is None or not is_leaf(tree):
-        if isinstance(tree, (list, tuple)):
-            for i, t in enumerate(tree):
-                flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
-            return flat_tree
-        if isinstance(tree, dict):
-            for k, t in tree.items():
-                flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
-            return flat_tree
-
-    return [(prefix[1:], tree)]
+    if destination is None:
+        destination = []
+
+    # Create the function to update the destination. We are taking advantage of
+    # the fact that list.extend and dict.update have the same API to simplify
+    # the code a bit.
+    if isinstance(destination, list):
+        _add_to_destination = destination.extend
+    elif isinstance(destination, dict):
+        _add_to_destination = destination.update
+    else:
+        raise ValueError("Destination should be either a list or a dictionary or None")
+
+    # Leaf identified by is_leaf so add it and return
+    if is_leaf is not None and is_leaf(tree):
+        _add_to_destination([(prefix[1:], tree)])
+        return destination
+
+    # List or tuple so recursively add each subtree
+    if isinstance(tree, (list, tuple)):
+        for i, item in enumerate(tree):
+            tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
+        return destination
+
+    # Dictionary so recursively add each subtree
+    if isinstance(tree, dict):
+        for key, value in tree.items():
+            tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
+        return destination
+
+    # Leaf so add it and return
+    _add_to_destination([(prefix[1:], tree)])
+
+    return destination
 
 
-def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
+def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
     """Recreate a Python tree from its flat representation.
 
     .. code-block:: python
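
The new `destination` argument lets callers get the flat tree back as a dictionary directly, which is what the `dict(tree_flatten(...))` call sites elsewhere in this diff are rewritten to use. A short usage sketch with an illustrative nested tree:

    from mlx.utils import tree_flatten

    params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0}]}

    print(tree_flatten(params))
    # [('layers.0.w', 1.0), ('layers.0.b', 0.0), ('layers.1.w', 2.0)]

    print(tree_flatten(params, destination={}))
    # {'layers.0.w': 1.0, 'layers.0.b': 0.0, 'layers.1.w': 2.0}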
@@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
         print(d)
         # {"hello": {"world": 42}}
 
+        d = tree_unflatten({"hello.world": 42})
+        print(d)
+        # {"hello": {"world": 42}}
+
     Args:
-        tree (list[tuple[str, Any]]): The flat representation of a Python tree.
+        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
             For instance as returned by :meth:`tree_flatten`.
 
     Returns:
         A Python tree.
     """
-    if len(tree) == 1 and tree[0][0] == "":
-        return tree[0][1]
+    items = tree.items() if isinstance(tree, dict) else tree
 
-    try:
-        int(tree[0][0].split(".", maxsplit=1)[0])
-        is_list = True
-    except ValueError:
-        is_list = False
+    # Special case when we have just one element in the tree ie not a tree
+    if len(items) == 1:
+        key, value = next(iter(items))
+        if key == "":
+            return value
 
     # collect children
     children = defaultdict(list)
-    for key, value in tree:
+    for key, value in items:
         current_idx, *next_idx = key.split(".", maxsplit=1)
         next_idx = "" if not next_idx else next_idx[0]
         children[current_idx].append((next_idx, value))
 
-    # recursively map them to the original container
-    if is_list:
+    # Assume they are a list and fail to dict if the keys are not all integers
+    try:
         keys = sorted((int(idx), idx) for idx in children.keys())
         l = []
         for i, k in keys:
@@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
                 l.extend([{} for _ in range(i - len(l))])
             l.append(tree_unflatten(children[k]))
         return l
-    else:
+    except ValueError:
         return {k: tree_unflatten(v) for k, v in children.items()}
 
 
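
`tree_unflatten` now accepts either representation, so a dictionary produced with `destination={}` round-trips without converting back to a list of pairs first. A hedged sketch:

    from mlx.utils import tree_flatten, tree_unflatten

    tree = {"hello": {"world": 42}}
    flat = tree_flatten(tree, destination={})   # {"hello.world": 42}
    assert tree_unflatten(flat) == tree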

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
         self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
 
         model = DictModule()
-        params = dict(tree_flatten(model.parameters()))
+        params = tree_flatten(model.parameters(), destination={})
         self.assertEqual(len(params), 2)
         self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
         self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))

@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
     CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
   }
 
+  // sum and prod overflow
+  {
+    auto a = full({256, 2, 2}, 1u, uint8);
+    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
+    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
+
+    a = full({65535, 2, 2}, 1u, uint16);
+    CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
+    CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
+  }
+}
+
+TEST_CASE("test gpu reduce with axes") {
   // reducing only some axes and irregular layouts
   {
     array a(1.0f);

@@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
     CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
   }
 
+  // Test unsigned sum
+  {
+    const int num_elems = 1000;
+
+    auto x = astype(full({num_elems}, 255), uint8);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
+
+    x = astype(full({num_elems}, 65535), uint16);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
+
+    x = full({3, 3, 3}, 10000, uint32);
+    CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
+
+    x = full({3, 3, 3}, 10000, uint64);
+    CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
+  }
+
   // Test prod
   {
     auto x = array({});
@@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
     CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
   }
 
+  // Test unsigned prod
+  {
+    auto x = array({255, 255}, {2}, uint8);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
+
+    x = array({65535, 2}, {2}, uint16);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
+
+    x = array({100000, 2}, {2}, uint32);
+    CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
+
+    x = array({100000, 2}, {2}, uint64);
+    CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
+  }
+
   // Test all
   {
     auto x = array({});