Compare commits

..

8 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a22d0bf273 Add stricter condition to matrix sdpa 2025-08-06 19:51:14 -07:00
Jagrit Digani
99d8de8445 Fix cudnn routing 2025-08-06 15:05:58 -07:00
Jagrit Digani
c66b76a8c8 Update routing 2025-08-06 15:01:15 -07:00
Jagrit Digani
f81edd184f Complete 2 pass sdpav 2025-08-06 13:57:40 -07:00
Jagrit Digani
7f8ba2a003 [WIP] 2 pass sdpav 2025-08-06 09:56:39 -07:00
Jagrit Digani
c28249b81a Add more nvtx range for debug 2025-08-06 09:56:39 -07:00
Jagrit Digani
e74bcdc5e3 Add sdpa file 2025-08-06 09:56:39 -07:00
Jagrit Digani
d8ed6c1aa3 Add base cudnn attention support 2025-08-06 09:56:39 -07:00
27 changed files with 779 additions and 1089 deletions

View File

@@ -1,5 +1,4 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

View File

@@ -18,7 +18,6 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -491,27 +491,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:

View File

@@ -24,7 +24,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -54,10 +53,10 @@ target_sources(
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

View File

@@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core {
namespace mlx::core::cu {
void CublasGemm::run_batched(
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void CublasGemm::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
execute(
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void CublasGemm::run_batched(
}
}
void CublasGemm::run_batched(
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void CublasGemm::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
execute(
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void CublasGemm::run_batched(
}
}
} // namespace mlx::core
} // namespace mlx::core::cu

View File

@@ -0,0 +1,208 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -7,12 +7,10 @@
#include <fmt/format.h>
namespace mlx::core {
namespace {
namespace mlx::core::cu {
struct CublasPreference {
CublasPreference(cu::Device& device) {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -35,7 +33,7 @@ struct CublasPreference {
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
@@ -54,7 +52,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
@@ -72,7 +70,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
@@ -104,10 +102,8 @@ cublasLtMatrixLayout_t create_matrix_layout(
return desc;
}
} // namespace
CublasGemm::CublasGemm(
cu::Device& device,
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -159,8 +155,8 @@ CublasGemm::CublasGemm(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
CublasGemm::CublasGemm(
cu::Device& device,
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -175,7 +171,7 @@ CublasGemm::CublasGemm(
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: CublasGemm(
: Matmul(
device,
dtype,
a_transposed,
@@ -194,7 +190,7 @@ CublasGemm::CublasGemm(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
CublasGemm::~CublasGemm() {
Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -202,73 +198,7 @@ CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
void Matmul::run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -326,4 +256,29 @@ void CublasGemm::execute(
encoder.stream()));
}
} // namespace mlx::core
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core {
class CublasGemm {
namespace mlx::core::cu {
class Matmul {
public:
CublasGemm(
cu::Device& device,
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -25,8 +25,8 @@ class CublasGemm {
int64_t a_batch_stride,
int64_t b_batch_stride);
CublasGemm(
cu::Device& device,
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -42,54 +42,41 @@ class CublasGemm {
int64_t b_batch_stride,
int64_t c_batch_stride);
~CublasGemm();
~Matmul();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
void run(
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
void execute(
void run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -110,4 +97,4 @@ class CublasGemm {
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core
} // namespace mlx::core::cu

View File

@@ -1,327 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
} // namespace
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
{batch_count * 3},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_mm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{batch_count * 4},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core

View File

@@ -1,301 +0,0 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
#include <cooperative_groups.h>
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
struct GemmParams {
int M;
int N;
int K;
int lda;
int ldb;
int ldd;
int NblockM;
int NblockN;
int NblockK;
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int SL,
int Nstages>
__global__ void kernel_steel_gemm(
const T* a,
const T* b,
T* d,
__grid_constant__ const GemmParams params) {
const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
const int bN_idx = blockIdx.x >> SL;
if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
return;
}
const int d_row = bM_idx * BM;
const int d_col = bN_idx * BN;
const size_t d_row_long = size_t(d_row);
const size_t d_col_long = size_t(d_col);
a += transpose_a ? d_row_long : d_row_long * params.K;
b += transpose_b ? d_col_long * params.K : d_col_long;
d += d_row_long * params.ldd + d_col_long;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
const int wm = warp_idx / WN;
const int wn = warp_idx % WN;
constexpr int SM = BM / WM;
constexpr int SN = BN / WN;
constexpr int SK = BK;
constexpr int TK = SK / 16;
constexpr int NUM_WARPS = WM * WN;
// Allocate shared memory
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&as)[Nstages] =
*(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
&shmem[sizeof(T) * Nstages * BM * BK]);
// Allocate registers for the MMA
RegisterTile<float, SM, SN> C;
RegisterTile<T, SM, 16> A[TK];
RegisterTile<T, SN, 16> B[TK];
// Zero the accumulators
C.fill(0);
// Start gmem -> smem copies
int k_block_read = 0;
MLX_UNROLL
for (int bk = 0; bk < (Nstages - 1); bk++) {
load_async<NUM_WARPS>(
as[bk], as[bk].base_addr(), a + k_block_read, params.K);
load_async<NUM_WARPS>(
bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
k_block_read += BK;
cp_async_commit();
}
int smem_pipe_read = 0;
int smem_pipe_write = Nstages - 1;
// Wait till only 1 remains laoding
cp_async_wait<1>();
block.sync();
const int offset_m = wm * SM;
const int offset_n = wn * SN;
// Start smem -> register copy
A[0].load(
as[smem_pipe_read],
as[smem_pipe_read].base_addr(),
offset_m + lane_idx % 16,
lane_idx / 16 * 8);
B[0].load(
bs[smem_pipe_read],
bs[smem_pipe_read].base_addr(),
offset_n + lane_idx % 16,
lane_idx / 16 * 8);
// Main loop
for (int kb = 0; kb < params.NblockK; kb++) {
// Prepare next registers
{
A[1].load(
as[smem_pipe_read],
as[smem_pipe_read].base_addr(),
offset_m + lane_idx % 16,
16 + lane_idx / 16 * 8);
B[1].load(
bs[smem_pipe_read],
bs[smem_pipe_read].base_addr(),
offset_n + lane_idx % 16,
16 + lane_idx / 16 * 8);
}
// Prepare next smem
if ((kb + Nstages - 1) < params.NblockK) {
load_async<NUM_WARPS>(
as[smem_pipe_write],
as[smem_pipe_write].base_addr(),
a + k_block_read,
params.K);
load_async<NUM_WARPS>(
bs[smem_pipe_write],
bs[smem_pipe_write].base_addr(),
b + k_block_read,
params.K);
}
k_block_read += BK;
cp_async_commit();
smem_pipe_write = smem_pipe_read;
smem_pipe_read = smem_pipe_read + 1;
smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;
// Do current gemm
mma_t(C, A[0], B[0]);
// Do wait for next register
cp_async_wait<1>();
block.sync();
// Prepare next register (smem_pipe_read has moved to the next)
{
A[0].load(
as[smem_pipe_read],
as[smem_pipe_read].base_addr(),
offset_m + lane_idx % 16,
lane_idx / 16 * 8);
B[0].load(
bs[smem_pipe_read],
bs[smem_pipe_read].base_addr(),
offset_n + lane_idx % 16,
lane_idx / 16 * 8);
}
// Do current gemm
mma_t(C, A[1], B[1]);
}
// Wait and clear
cp_async_wait_all();
block.sync();
C.store_global(d, params.ldd, offset_m, offset_n);
}
} // namespace cu
void dispatch_steel_gemm(
const Stream& s,
cu::CommandEncoder& encoder,
const array& a,
const array& b,
array& d,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool a_transposed,
bool b_transposed) {
using DataType = cuda_type_t<float16_t>;
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(d);
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int SL = 0;
constexpr int Nstages = 3;
constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);
const int NblockM = (M + BM - 1) / BM;
const int NblockN = (N + BN - 1) / BN;
const int NblockK = (K + BK - 1) / BK;
cu::GemmParams params{
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* int NblockM = */ NblockM,
/* int NblockN = */ NblockN,
/* int NblockK = */ NblockK,
};
// Prepare launch grid params
int tile = 1 << SL;
int tm = (NblockM + tile - 1) / tile;
int tn = NblockN * tile;
dim3 grid_dim(tn, tm, 1);
dim3 block_dim(32 * WM * WN, 1, 1);
dispatch_bool(a_transposed, [&](auto ta_) {
dispatch_bool(b_transposed, [&](auto tb_) {
constexpr bool ta = ta_.value;
constexpr bool tb = tb_.value;
auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
smem_bytes,
a.data<DataType>(),
b.data<DataType>(),
d.data<DataType>(),
N,
K);
// auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
// tb, SL, Nstages>;
// cudaFuncSetAttribute(kernel,
// cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
// encoder.add_kernel_node(
// kernel,
// grid_dim,
// block_dim,
// smem_bytes,
// a.data<DataType>(),
// b.data<DataType>(),
// d.data<DataType>(),
// params);
});
});
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
#pragma once
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
void dispatch_steel_gemm(
const Stream& s,
cu::CommandEncoder& encoder,
const array& a,
const array& b,
array& d,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool a_transposed,
bool b_transposed);
} // namespace mlx::core

View File

@@ -7,8 +7,6 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include <nvtx3/nvtx3.hpp>
#include <numeric>
@@ -97,27 +95,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
b_transposed) {
return dispatch_steel_gemm(
/* const Stream& s = */ s,
/* cu::CommandEncoder& encoder = */ encoder,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& d = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ N,
/* bool a_transposed = */ a_transposed,
/* bool b_transposed = */ b_transposed);
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::Matmul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -131,7 +111,14 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -199,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::Matmul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -215,7 +202,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
gemm.run(
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
encoder,
out,
a,

View File

@@ -8,13 +8,19 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace fe = cudnn_frontend;
namespace mlx::core {
namespace cu {
@@ -639,6 +645,294 @@ void sdpa_vector_fallback(
}
}
struct SDPACacheKey {
int device_id;
fe::DataType_t cudnn_type;
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
bool generate_stats;
bool causal_mask;
};
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
cache(
/* capacity */ 128);
return cache;
}
#define Q_UID 1
#define K_UID 2
#define V_UID 3
#define O_UID 4
#define STATS_UID 5
std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
cu::CommandEncoder& encoder,
const SDPACacheKey& cache_key) {
// Check if graph has already been fully built
if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
return it->second;
}
// Set up new graph
auto graph = std::make_shared<fe::graph::Graph>();
graph->set_io_data_type(cache_key.cudnn_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_uid(Q_UID)
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
.set_stride(
{cache_key.Q_strides[0],
cache_key.Q_strides[1],
cache_key.Q_strides[2],
1}));
int h_kv = cache_key.H / cache_key.gqa_factor;
auto K =
graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_uid(K_UID)
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
.set_stride(
{cache_key.K_strides[0],
cache_key.K_strides[1],
cache_key.V_strides[2],
1}));
auto V =
graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_uid(V_UID)
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
.set_stride(
{cache_key.V_strides[0],
cache_key.V_strides[1],
cache_key.V_strides[2],
1}));
auto sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_is_inference(!cache_key.generate_stats)
.set_attn_scale(cache_key.scale);
if (cache_key.causal_mask && cache_key.qL > 1) {
sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
.set_diagonal_band_right_bound(0);
}
auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
O->set_output(true)
.set_uid(O_UID)
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
.set_stride(
{cache_key.O_strides[0],
cache_key.O_strides[1],
cache_key.O_strides[2],
1});
if (cache_key.generate_stats) {
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_uid(STATS_UID);
}
// Build and Validate cudnn graph
auto handle = encoder.device().cudnn_handle();
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
if (cudnnGetVersion() < 90600) {
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
if (!build_status.is_good()) {
throw std::runtime_error(
"Unable to build cudnn graph for attention."
" Failed with message: " +
build_status.get_message());
}
} else {
auto val_status = graph->validate();
auto op_status = graph->build_operation_graph(handle);
auto plan_stauts =
graph->create_execution_plans({cudnn_frontend::HeurMode_t::A});
if (!plan_stauts.is_good()) {
throw std::runtime_error(
"Unable to create exec plan for cudnn attention."
" Failed with message: " +
plan_stauts.get_message());
}
graph->select_behavior_notes(
{cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
auto support_status = graph->check_support(handle);
if (!support_status.is_good()) {
throw std::runtime_error(
"No cuda graph support for cudnn attention."
" Failed with message: " +
support_status.get_message());
}
auto build_status = graph->build_plans(handle);
if (!build_status.is_good()) {
throw std::runtime_error(
"Unable to build cudnn graph for attention."
" Failed with message: " +
build_status.get_message());
}
}
auto [it, _] = sdpa_cache().emplace(cache_key, graph);
return it->second;
}
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
}
}
void sdpa_cudnn(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
auto cudnn_type = dtype_to_cudnn_type(q.dtype());
int B = q.shape(0);
int H = q.shape(1);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
int qL = q.shape(2);
int kL = k.shape(2);
SDPACacheKey cache_key{
/* int device_id = */ encoder.device().cuda_device(),
/* fe::DataType_t cudnn_type = */ cudnn_type,
/* int B = */ B,
/* int H = */ H,
/* int D = */ D,
/* int qL = */ qL,
/* int kL = */ kL,
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)},
/* bool generate_stats = */ false,
/* bool causal_mask = */ do_causal_};
auto graph = get_sdpa_forward_graph(encoder, cache_key);
int64_t workspace_size = 0;
auto workspace_status = graph->get_workspace_size(workspace_size);
if (!workspace_status.is_good()) {
throw std::runtime_error("Unable to get workspace for cudnn attention.");
}
array workspace(
allocator::malloc(workspace_size), {int(workspace_size)}, uint8);
auto workspace_ptr = workspace.data<void>();
std::unordered_map<int64_t, void*> variant_pack = {
{Q_UID, const_cast<void*>(q.data<void>())},
{K_UID, const_cast<void*>(k.data<void>())},
{V_UID, const_cast<void*>(v.data<void>())},
{O_UID, o.data<void>()}};
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
if (cudnnGetVersion() < 90600) {
auto capture = encoder.capture_context();
auto exec_status = graph->execute(handle, variant_pack, workspace_ptr);
if (!exec_status.is_good()) {
capture.discard = true;
throw std::runtime_error(
"Unable to execute cudnn attention."
" Failed with message: " +
exec_status.get_message());
}
} else {
cudaGraph_t cu_graph;
cudaGraphCreate(&cu_graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
auto cu_graph_status = graph->populate_cuda_graph(
handle, variant_pack, workspace_ptr, cu_graph);
if (!cu_graph_status.is_good()) {
throw std::runtime_error(
"Unable to add cuda graph for cudnn attention."
" Failed with message: " +
cu_graph_status.get_message());
}
encoder.add_graph_node(cu_graph);
}
encoder.add_temporary(workspace);
}
} // namespace
namespace fast {
@@ -651,9 +945,6 @@ bool ScaledDotProductAttention::use_fallback(
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
@@ -669,7 +960,15 @@ bool ScaledDotProductAttention::use_fallback(
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
auto& cu_device = cu::device(s.device);
const bool supported_matrix_config = query_sequence_length > 4 &&
cu_device.compute_capability_major() >= 8 &&
query_sequence_length == key_sequence_length &&
(q.dtype() == float16 || q.dtype() == bfloat16);
const bool supported_config =
(supported_matrix_config || supported_vector_config);
return has_arr_mask || !supported_config;
}
@@ -703,6 +1002,10 @@ void ScaledDotProductAttention::eval_gpu(
}
};
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
@@ -756,7 +1059,7 @@ void ScaledDotProductAttention::eval_gpu(
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
@@ -770,9 +1073,35 @@ void ScaledDotProductAttention::eval_gpu(
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here
// Full attention mode
else {
throw std::runtime_error("Doesn't support matrix yet.");
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
int64_t str_oD = 1;
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ 0,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
}
}

View File

@@ -143,87 +143,85 @@ struct Tile16x16 {
}
};
// /**
// * A simple container of multiple Tile16x16.
// *
// * Provides utility functions for loading and manipulating collections of
// basic
// * tiles.
// */
// template <typename T, int ROWS_, int COLS_>
// struct RegisterTile {
// static constexpr int ROWS = ROWS_;
// static constexpr int COLS = COLS_;
// static constexpr int TILES_X = COLS / 16;
// static constexpr int TILES_Y = ROWS / 16;
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
// Tile16x16<T> data[TILES_X * TILES_Y];
Tile16x16<T> data[TILES_X * TILES_Y];
// __device__ inline void fill(T v) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].fill(v);
// }
// }
// }
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
// template <typename Tile>
// __device__ __forceinline__ void
// load(Tile& tile, uint32_t base_address, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].load(
// tile.loc(base_address, row + i * 16, col + j * 16));
// }
// }
// }
template <typename Tile>
__device__ __forceinline__ void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
// template <typename Tile, typename F>
// __device__ __forceinline__ void
// load(Tile& tile, F f, uint32_t base_address, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// f(data[i * TILES_X + j],
// tile,
// base_address,
// row + i * 16,
// col + j * 16);
// }
// }
// }
template <typename Tile, typename F>
__device__ __forceinline__ void
load(Tile& tile, F f, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
f(data[i * TILES_X + j],
tile,
base_address,
row + i * 16,
col + j * 16);
}
}
}
// template <typename U>
// __device__ inline void store_global(U* x, int N, int row, int col) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].store_global(
// x + (row + i * 16) * N + col + j * 16, N);
// }
// }
// }
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
// template <typename U>
// __device__ inline void
// store_global_safe(U* x, int N, int row, int col, int max_rows) {
// MLX_UNROLL
// for (int i = 0; i < TILES_Y; i++) {
// MLX_UNROLL
// for (int j = 0; j < TILES_X; j++) {
// data[i * TILES_X + j].store_global_safe(
// x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
// 16);
// }
// }
// }
// };
template <typename U>
__device__ inline void
store_global_safe(U* x, int N, int row, int col, int max_rows) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global_safe(
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
}
}
}
};
/**
* A simple container of multiple Tile16x16.

View File

@@ -104,7 +104,7 @@ struct CommandEncoder {
};
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*>& outputs() {
std::unordered_set<const void*> outputs() {
return all_outputs_;
};

View File

@@ -134,10 +134,6 @@ instantiate_and_or(and, And)
instantiate_and_or(or, Or)
#define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \

View File

@@ -247,25 +247,15 @@ std::pair<Dtype, Dtype> remap_reduce_types(
const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) {
switch (in.dtype()) {
case uint8:
return {uint8, uint32};
case uint16:
return {uint16, uint32};
case uint32:
return {uint32, uint32};
case uint64:
return {uint64, uint64};
case int8:
switch (in.dtype().size()) {
case 1:
return {int8, int32};
case int16:
case 2:
return {int16, int32};
case int32:
case 4:
return {int32, int32};
case int64:
case 8:
return {int64, int64};
default:
throw std::runtime_error("Unsupported integer type");
}
}
if (in.dtype() == bool_) {

View File

@@ -2381,20 +2381,9 @@ array logsumexp(
throw std::invalid_argument(
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
}
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && reduce_last_dim) {
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
auto out_shape = a.shape();
out_shape.back() = 1;
@@ -3414,20 +3403,10 @@ array softmax(
throw std::invalid_argument(
"[softmax] Received non-empty axes for array with 0 dimensions.");
}
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && reduce_last_dim) {
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(
a.shape(),

View File

@@ -3,8 +3,8 @@
#pragma once
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 28
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_MINOR 27
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -2,6 +2,6 @@
requires = [
"setuptools>=80",
"nanobind==2.4.0",
"cmake>=3.25,<4.1",
"cmake>=3.25",
]
build-backend = "setuptools.build_meta"

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict:
new_weights = dict(weights)
curr_weights = tree_flatten(self.parameters(), destination={})
curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras)
extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors`
"""
params_dict = tree_flatten(self.parameters(), destination={})
params_dict = dict(tree_flatten(self.parameters()))
if file.endswith(".npz"):
mx.savez(file, **params_dict)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple
def tree_map(
@@ -114,11 +114,8 @@ def tree_map_with_path(
def tree_flatten(
tree: Any,
prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -131,12 +128,9 @@ def tree_flatten(
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], prefix=".hello"))
print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note::
Dictionaries should have keys that are valid Python identifiers.
@@ -146,50 +140,26 @@ def tree_flatten(
always discarded.
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns:
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
the Python tree.
List[Tuple[str, Any]]: The flat representation of the Python tree.
"""
if destination is None:
destination = []
flat_tree = []
# Create the function to update the destination. We are taking advantage of
# the fact that list.extend and dict.update have the same API to simplify
# the code a bit.
if isinstance(destination, list):
_add_to_destination = destination.extend
elif isinstance(destination, dict):
_add_to_destination = destination.update
else:
raise ValueError("Destination should be either a list or a dictionary or None")
if is_leaf is None or not is_leaf(tree):
if isinstance(tree, (list, tuple)):
for i, t in enumerate(tree):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return flat_tree
if isinstance(tree, dict):
for k, t in tree.items():
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return flat_tree
# Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
return [(prefix[1:], tree)]
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -200,34 +170,31 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
print(d)
# {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args:
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A Python tree.
"""
items = tree.items() if isinstance(tree, dict) else tree
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
# Special case when we have just one element in the tree ie not a tree
if len(items) == 1:
key, value = next(iter(items))
if key == "":
return value
try:
int(tree[0][0].split(".", maxsplit=1)[0])
is_list = True
except ValueError:
is_list = False
# collect children
children = defaultdict(list)
for key, value in items:
for key, value in tree:
current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value))
# Assume they are a list and fail to dict if the keys are not all integers
try:
# recursively map them to the original container
if is_list:
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
@@ -235,7 +202,7 @@ def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k]))
return l
except ValueError:
else:
return {k: tree_unflatten(v) for k, v in children.items()}

View File

@@ -80,7 +80,7 @@ class TestBase(mlx_tests.MLXTestCase):
self.weights = {"w1": mx.zeros((2, 2)), "w2": mx.ones((2, 2))}
model = DictModule()
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
self.assertEqual(len(params), 2)
self.assertTrue(mx.array_equal(params["weights.w1"], mx.zeros((2, 2))))
self.assertTrue(mx.array_equal(params["weights.w2"], mx.ones((2, 2))))

View File

@@ -155,19 +155,6 @@ TEST_CASE("test gpu reduce") {
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
}
// sum and prod overflow
{
auto a = full({256, 2, 2}, 1u, uint8);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
a = full({65535, 2, 2}, 1u, uint16);
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
}
}
TEST_CASE("test gpu reduce with axes") {
// reducing only some axes and irregular layouts
{
array a(1.0f);

View File

@@ -915,23 +915,6 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
}
// Test unsigned sum
{
const int num_elems = 1000;
auto x = astype(full({num_elems}, 255), uint8);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
x = astype(full({num_elems}, 65535), uint16);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
x = full({3, 3, 3}, 10000, uint32);
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
x = full({3, 3, 3}, 10000, uint64);
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
}
// Test prod
{
auto x = array({});
@@ -964,21 +947,6 @@ TEST_CASE("test reduction ops") {
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
}
// Test unsigned prod
{
auto x = array({255, 255}, {2}, uint8);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
x = array({65535, 2}, {2}, uint16);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
x = array({100000, 2}, {2}, uint32);
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
x = array({100000, 2}, {2}, uint64);
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
}
// Test all
{
auto x = array({});