Compare commits

...

21 Commits

Author SHA1 Message Date
Awni Hannun
cea9369610 fix lapack svd (#2515) 2025-08-18 15:07:59 -07:00
Awni Hannun
e7c6e1db82 no segfault with uninitialized array.at (#2514) 2025-08-18 08:33:38 -07:00
Awni Hannun
c5fcd5b61b fix custom kernel test (#2510) 2025-08-18 06:45:59 -07:00
Angelos Katharopoulos
1df9887998 Ensure no oob read in gemv_masked (#2508) 2025-08-17 08:42:33 -07:00
Angelos Katharopoulos
73f22d6226 Ensure small sort doesn't use indices if not argsort (#2506) 2025-08-17 08:42:20 -07:00
Cheng
c422050ca7 Update cuDNN Frontend to v1.14 (#2505) 2025-08-17 19:13:01 +09:00
Cheng
1ba18ff7d9 [CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file

* [CUDA] Fix conv grads with groups

* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Cheng
37b440faa8 Clean up code handling both std::vector and SmallVector (#2493) 2025-08-16 09:01:10 +09:00
Cheng
888b13ed63 Remove the hack around SmallVector in cpu compile (#2494) 2025-08-16 08:17:24 +09:00
Cheng
4abb218d21 The naive_conv_2d is no longer used (#2496) 2025-08-16 07:57:30 +09:00
Awni Hannun
6441c21a94 Faster general unary op (#2472)
* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
2025-08-15 15:04:12 -07:00
Cheng
dfb5022eab Rename cu::Matmul to CublasGemm (#2488) 2025-08-13 09:37:40 +09:00
Daniel Yeh
ac207ce7aa make code blocks copyable (#2480)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-08-12 12:29:02 -07:00
Abe Leininger
fce53b61d6 Fix reduce sum/prod overflow (#2477) 2025-08-12 00:05:33 -07:00
Angelos Katharopoulos
8ae4a76308 Use CMake <4.1 to avoid the nvpl error (#2489) 2025-08-12 00:03:42 -07:00
Cheng
7fde1b6a1e Fix logsumexp/softmax not fused for some cases (#2474) 2025-08-08 14:07:17 -07:00
Cheng
aa7b47481a [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) 2025-08-08 15:23:30 +09:00
Awni Hannun
56be773610 version (#2470) 2025-08-07 00:36:04 -07:00
Jagrit Digani
a9bdd67baa Add CUDA sdpa vector (#2468) 2025-08-06 21:40:26 -07:00
Angelos Katharopoulos
f2adb5638d Fix typo in metal command encoder (#2471) 2025-08-06 16:58:23 -07:00
Luca Vivona
728d4db582 Support destination arg in tree flatten/unflatten (#2450) 2025-08-06 15:34:59 -07:00
107 changed files with 2899 additions and 920 deletions

View File

@@ -1,4 +1,5 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

View File

@@ -18,6 +18,7 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -128,6 +128,7 @@ relying on a copy from ``ensure_row_contiguous``:
input_names=["inp"],
output_names=["out"],
source=source
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = dict(tree_flatten(model.parameters()))
params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -228,31 +228,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -196,9 +196,6 @@ void shared_buffer_reshape(
const Strides& out_strides,
array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));

View File

@@ -157,10 +157,12 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
@@ -175,8 +177,8 @@ inline void build_kernel(
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
}
}
@@ -186,10 +188,8 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
// Add output size
if (contiguous) {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -288,17 +288,8 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
@@ -306,9 +297,6 @@ void Compiled::eval_cpu(
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
}
// Get the kernel name from the lib
@@ -343,16 +331,20 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (!contiguous) {
args.push_back((void*)shape.data());
} else {
if (contiguous) {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable { fun(args.data()); });
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
}
} // namespace mlx::core

View File

@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)

View File

@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:

View File

@@ -81,9 +81,7 @@ void svd_impl(
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
@@ -91,30 +89,20 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
@@ -136,20 +124,13 @@ void svd_impl(
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@@ -167,13 +148,6 @@ void svd_impl(
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
});
encoder.add_temporary(in);

View File

@@ -8,7 +8,6 @@ target_sources(
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -39,23 +38,26 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -147,7 +149,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.12.1
GIT_TAG v1.14.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -0,0 +1,21 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -99,39 +99,89 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -209,39 +259,61 @@ void binary_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] =
get_launch_args(out, large());
auto kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node(
cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
cu::binary_g<Op, InType, OutType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
rest,
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@@ -304,54 +376,4 @@ void binary_op_gpu(
binary_op_gpu<cu::func>(inputs, out, name(), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(GreaterEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Less)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LessEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogAddExp)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalAnd)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalOr)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Maximum)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Minimum)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Multiply)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(NotEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Power)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Remainder)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Subtract)
} // namespace mlx::core

View File

@@ -127,45 +127,99 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -225,42 +279,64 @@ void binary_two_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out_a.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node(
cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
cu::binary_two_g<Op, InType, OutType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
rest,
const_param(shape),
const_param(a_strides),
const_param(b_strides),

View File

@@ -7,9 +7,6 @@
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
@@ -336,6 +333,42 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
}
}
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
const array& x,
int groups,
int group_dim,
int axis1,
int axis2,
Stream s) {
if (groups == 1) {
return swapaxes_in_eval(x, axis1, axis2);
}
int ndim = x.ndim();
if (group_dim < 0) {
group_dim += ndim;
}
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
if (group_dim <= axis1) {
axis1 += 1;
}
if (group_dim <= axis2) {
axis2 += 1;
}
auto shape = x.shape();
shape.insert(shape.begin() + group_dim, groups);
shape[group_dim + 1] = shape[group_dim + 1] / groups;
array x_trans = reshape_in_eval(x, std::move(shape), s);
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
return x_trans;
}
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
@@ -345,13 +378,14 @@ std::tuple<array, array, array> prepare_args(
array in,
array wt,
array out,
int groups,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = swapaxes_in_eval(wt, 0, -1);
wt = group_transpose(wt, groups, 0, 0, -1, s);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = swapaxes_in_eval(in, 0, -1);
in = group_transpose(in, groups, -1, 0, -1, s);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped.
@@ -457,7 +491,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
@@ -490,7 +525,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, s);
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,

View File

@@ -10,37 +10,80 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
__global__ void copy_gg_nd(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto in_stride_x = strides_in[NDIM - 1];
auto out_stride_x = strides_out[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index_rest * shape_x,
shape.data(),
strides_in.data(),
strides_out.data());
auto in_vec =
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}
template <typename In, typename Out, typename IdxT>
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_gg(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto in_stride_x = strides_in[ndim - 1];
auto out_stride_x = strides_out[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [idx_in, idx_out] = elem_to_loc(
index_rest * shape_x,
shape.data(),
strides_in.data(),
strides_out.data(),
ndim);
auto in_vec =
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}
} // namespace cu
@@ -69,33 +112,52 @@ void copy_general(
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = data_size / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
auto kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
}
encoder.add_kernel_node(
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
data_size,
rest,
const_param<ndim_constant()>(shape),
const_param<ndim_constant()>(strides_in),
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
cu::copy_gg<InType, OutType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
data_size,
rest,
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@@ -10,33 +10,67 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
__global__ void copy_g_nd(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
out[index] = CastOp<In, Out>{}(in[idx_in]);
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto stride_x = strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename In, typename Out, typename IdxT>
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_g(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
} // namespace cu
@@ -61,30 +95,49 @@ void copy_general_input(
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
encoder.add_kernel_node(
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
cu::copy_g<InType, OutType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
rest,
const_param(shape),
const_param(strides_in),
ndim);

View File

@@ -146,6 +146,23 @@ inline __device__ void store_vector(
}
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size,
int64_t stride) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[stride * (offset * N + i)] = vec[i];
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

View File

@@ -1,208 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -7,10 +7,12 @@
#include <fmt/format.h>
namespace mlx::core::cu {
namespace mlx::core {
namespace {
struct CublasPreference {
CublasPreference(Device& device) {
CublasPreference(cu::Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -33,7 +35,7 @@ struct CublasPreference {
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
}
}
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
}
}
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
return desc;
}
Matmul::Matmul(
Device& device,
} // namespace
CublasGemm::CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -155,8 +159,8 @@ Matmul::Matmul(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
Matmul::Matmul(
Device& device,
CublasGemm::CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -171,7 +175,7 @@ Matmul::Matmul(
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: Matmul(
: CublasGemm(
device,
dtype,
a_transposed,
@@ -190,7 +194,7 @@ Matmul::Matmul(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
Matmul::~Matmul() {
CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void Matmul::run_impl(
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -256,29 +326,4 @@ void Matmul::run_impl(
encoder.stream()));
}
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core::cu {
class Matmul {
namespace mlx::core {
class CublasGemm {
public:
Matmul(
Device& device,
CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
int64_t a_batch_stride,
int64_t b_batch_stride);
Matmul(
Device& device,
CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -42,25 +42,39 @@ class Matmul {
int64_t b_batch_stride,
int64_t c_batch_stride);
~Matmul();
~CublasGemm();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
@@ -68,15 +82,14 @@ class Matmul {
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_impl(
void execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -97,4 +110,4 @@ class Matmul {
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu {
namespace mlx::core {
void Matmul::run_batched(
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void Matmul::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void Matmul::run_batched(
}
}
void Matmul::run_batched(
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void Matmul::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void Matmul::run_batched(
}
}
} // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -0,0 +1,327 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
} // namespace
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
{batch_count * 3},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_mm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{batch_count * 4},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core

View File

@@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -111,14 +111,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -186,7 +179,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -202,12 +195,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
gemm.run(
encoder,
out,
a,

View File

@@ -6,17 +6,6 @@
namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

View File

@@ -0,0 +1,781 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
#define PRAGMA_LOOP_UNROLL #pragma unroll
struct AttnParams {
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
};
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_1pass(
const T* Q,
const T* K,
const T* V,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = BN * int(params.K_strides[2]);
const int inner_v_stride = BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
// Read the query and 0 the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -INFINITY;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = max_scores[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
float* partials,
float* sums,
float* maxs,
__grid_constant__ const AttnParams params) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z / blocks;
const int block_idx = blockIdx.z % blocks;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = block_idx * BN + warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s + // Sequence
block_idx; // Block
partials += p_offset * D;
sums += p_offset;
maxs += p_offset;
// Read the query and 0 the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -1e9;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
// Write the sum and new max
if (warp_idx == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
auto ff = exp2f(max_scores[warp_idx] - new_max);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[warp_idx][lane_idx] = o[i] * ff;
block.sync();
if (warp_idx == 0) {
U ot = outputs[0][lane_idx];
PRAGMA_LOOP_UNROLL
for (int j = 1; j < BN; j++) {
ot += outputs[j][lane_idx];
warp.sync();
}
o[i] = ot;
}
block.sync();
}
if (warp_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
partials[v_per_thread * lane_idx + i] = o[i];
}
}
}
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
const float* partials,
const float* sums,
const float* maxs,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
typedef float U;
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int q_seq_idx = blockIdx.y;
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s; // Sequence
partials += p_offset * D + warp_idx * D;
sums += p_offset;
maxs += p_offset;
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
U max_score = maxs[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = partials[v_per_thread * lane_idx + i];
}
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
} // namespace cu
namespace {
template <typename F>
void dispatch_headdim(int n, F&& f) {
switch (n) {
case 64:
f(std::integral_constant<int, 64>{});
break;
case 96:
f(std::integral_constant<int, 96>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
void sdpa_vector_1pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel =
cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
params);
});
});
});
}
void sdpa_vector_2pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
// Allocate the intermediates
int blocks = 32;
Shape intermediate_shape;
intermediate_shape.reserve(o.ndim() + 1);
intermediate_shape.insert(
intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
intermediate_shape.push_back(blocks);
intermediate_shape.push_back(o.shape().back());
array intermediate(intermediate_shape, float32, nullptr, {});
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
{
auto kernel = cu::
kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
dim3 grid_dim(params.H, params.qL, params.B * 32);
dim3 block_dim(8 * 32, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
params);
}
{
auto kernel = cu::
kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(intermediate);
encoder.set_input_array(sums);
encoder.set_input_array(maxs);
encoder.set_output_array(o);
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
params);
}
});
});
});
}
void sdpa_vector_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
}
}
} // namespace
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
return has_arr_mask || !supported_config;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {
return arr;
}
};
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
int64_t str_oD = 1;
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here
else {
throw std::runtime_error("Doesn't support matrix yet.");
}
}
} // namespace fast
} // namespace mlx::core

View File

@@ -1,6 +1,5 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"

View File

@@ -39,52 +39,98 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
}
}
template <typename Op, typename T, typename IdxT, int NDIM>
template <typename Op, typename T, typename IdxT, int NDIM, int N_READS>
__global__ void ternary_g_nd(
const bool* a,
const T* b,
const T* c,
T* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
index,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
auto c_stride_x = c_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));
auto c_vec =
load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename T, typename IdxT>
template <typename Op, typename T, typename IdxT, int N_READS>
__global__ void ternary_g(
const bool* a,
const T* b,
const T* c,
T* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
const __grid_constant__ Strides c_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc(
index,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data(),
ndim);
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
auto c_stride_x = c_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx, c_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
c_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));
auto c_vec =
load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
} // namespace cu
@@ -123,36 +169,55 @@ void ternary_op_gpu_inplace(
auto& b_strides = strides[1];
auto& c_strides = strides[2];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 4>;
}
encoder.add_kernel_node(
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides),
const_param<dims_constant()>(c_strides));
});
} else {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::ternary_g<Op, DType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::ternary_g<Op, DType, IdxT, 4>;
}
encoder.add_kernel_node(
cu::ternary_g<Op, DType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size(),
rest,
const_param(shape),
const_param(a_strides),
const_param(b_strides),

View File

@@ -37,19 +37,36 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
}
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -127,8 +144,7 @@ void unary_op_gpu_inplace(
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
@@ -142,18 +158,30 @@ void unary_op_gpu_inplace(
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto [num_blocks, block_dims] = get_launch_args(out, large);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
cu::unary_g<Op, InType, OutType, IdxT>,
num_blocks,
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size(),
rest,
const_param(shape),
const_param(strides),
shape.size());
ndim);
}
});
} else {

View File

@@ -0,0 +1,34 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu)

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Abs)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcCos)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcCosh)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcSin)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcSinh)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcTan)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ArcTanh)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(BitwiseInvert)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Ceil)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Conjugate)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Cos)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Cosh)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Erf)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(ErfInv)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Exp)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Expm1)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Floor)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Imag)
} // namespace mlx::core

View File

@@ -0,0 +1,21 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, name(), s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Log1p)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(LogicalNot)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Negative)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Real)
} // namespace mlx::core

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sigmoid)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sign)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sin)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Sinh)
} // namespace mlx::core

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sqrt::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Square)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Tan)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
namespace mlx::core {
UNARY_GPU(Tanh)
} // namespace mlx::core

View File

@@ -0,0 +1,215 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
std::is_same_v<Op, Sigmoid>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
rest,
const_param(shape),
const_param(strides),
ndim);
}
});
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \
}
} // namespace mlx::core

View File

@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
return arr_copy;
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
int ndim = x.ndim();
if (start_axis < 0) {
start_axis += ndim;
}
if (end_axis < 0) {
end_axis += ndim;
}
start_axis = std::max(0, start_axis);
end_axis = std::min(ndim - 1, end_axis);
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
}
array reshape_in_eval(const array& x, Shape shape, Stream s) {
array out(std::move(shape), x.dtype(), nullptr, {});
reshape_gpu(x, out, s);
return out;
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
// Copy data from |in| and transpose to |out|'s shape.
void reshape_gpu(const array& in, array& out, Stream s);
// Like the normal ops but safe to call in eval_gpu.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
array reshape_in_eval(const array& x, Shape shape, Stream s);
array swapaxes_in_eval(const array& x, int axis1, int axis2);
} // namespace mlx::core

View File

@@ -20,29 +20,6 @@
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
@@ -124,7 +101,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
reshape_gpu(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -150,7 +127,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
reshape_gpu(inputs[0], out, stream());
}
void Split::eval_gpu(
@@ -224,7 +201,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
reshape_gpu(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -60,22 +60,12 @@ struct CommandEncoder {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
// TODO: Code is duplicated but they should be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
void set_vector_bytes(const Vec& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
@@ -104,7 +94,7 @@ struct CommandEncoder {
};
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*> outputs() {
std::unordered_set<const void*>& outputs() {
return all_outputs_;
};

View File

@@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
out += tid.z * params.out_strides[0];
in += tid.z * params.in_strides[0];
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for (int h = 0; h < params.wS[0]; ++h) {
for (int w = 0; w < params.wS[1]; ++w) {
for (int c = 0; c < params.C; ++c) {
// Local in
for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c]
: T(0);
}
// Accumulate
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
(out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -262,36 +262,37 @@ struct GEMVKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
// Accumulate results
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
}
@@ -544,31 +545,32 @@ struct GEMVTKernel {
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
if (leftover > 0) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}

View File

@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
instantiate_and_or(or, Or)
#define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \

View File

@@ -45,7 +45,9 @@ struct ThreadSort {
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
if (ARG_SORT) {
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
@@ -111,7 +113,9 @@ struct BlockMergeSort {
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
if (ARG_SORT) {
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
}
b_idx += short(pred);
a_idx += short(!pred);

View File

@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
const std::string& op_name) {
if (op_name == "sum" || op_name == "prod") {
if (issubdtype(in.dtype(), integer)) {
switch (in.dtype().size()) {
case 1:
switch (in.dtype()) {
case uint8:
return {uint8, uint32};
case uint16:
return {uint16, uint32};
case uint32:
return {uint32, uint32};
case uint64:
return {uint64, uint64};
case int8:
return {int8, int32};
case 2:
case int16:
return {int16, int32};
case 4:
case int32:
return {int32, int32};
case 8:
case int64:
return {int64, int64};
default:
throw std::runtime_error("Unsupported integer type");
}
}
if (in.dtype() == bool_) {

View File

@@ -2381,9 +2381,20 @@ array logsumexp(
throw std::invalid_argument(
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
}
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
if (!is_complex && reduce_last_dim) {
auto dtype = at_least_float(a.dtype());
auto out_shape = a.shape();
out_shape.back() = 1;
@@ -2960,7 +2971,7 @@ array gather(
}
for (auto& x : indices) {
if (x.dtype() == bool_) {
throw("[Gather] Boolean indices not supported.");
throw std::invalid_argument("[Gather] Boolean indices not supported.");
}
}
@@ -3403,10 +3414,20 @@ array softmax(
throw std::invalid_argument(
"[softmax] Received non-empty axes for array with 0 dimensions.");
}
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
if (!is_complex && reduce_last_dim) {
auto dtype = at_least_float(a.dtype());
return array(
a.shape(),

View File

@@ -440,6 +440,7 @@ class SmallVector {
end_ = begin_;
}
private:
// Grows the backing store by a factor of two, and at least to {min_capacity}.
// TODO: Move to private after removing external code using this method.
MLX_NOINLINE void grow(size_t min_capacity = 0) {
@@ -469,7 +470,6 @@ class SmallVector {
end_of_storage_ = new_storage + new_capacity;
}
private:
MLX_NOINLINE void free_storage() {
std::destroy_n(begin_, end_ - begin_);
if (is_big()) {
@@ -519,6 +519,18 @@ class SmallVector {
std::is_trivially_destructible<T>::value;
};
template <typename>
struct is_vector : std::false_type {};
template <typename T, size_t Size, typename Allocator>
struct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};
template <typename T, typename Allocator>
struct is_vector<std::vector<T, Allocator>> : std::true_type {};
template <typename Vec>
inline constexpr bool is_vector_v = is_vector<Vec>::value;
#undef MLX_HAS_BUILTIN
#undef MLX_HAS_ATTRIBUTE
#undef MLX_LIKELY

View File

@@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
// TODO: Code is duplicated but they should be deleted soon.
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
namespace env {
int get_var(const char* name, int default_value) {

View File

@@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}
@@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
os << "(";
for (auto it = v.begin(); it != v.end(); ++it) {
os << *it;
if (it != std::prev(v.end())) {
os << ",";
}
}
os << ")";
return os;
}
inline bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}

View File

@@ -3,8 +3,8 @@
#pragma once
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 27
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_MINOR 28
#define MLX_VERSION_PATCH 0
#define MLX_VERSION_NUMERIC \
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -2,6 +2,6 @@
requires = [
"setuptools>=80",
"nanobind==2.4.0",
"cmake>=3.25",
"cmake>=3.25,<4.1",
]
build-backend = "setuptools.build_meta"

View File

@@ -178,7 +178,7 @@ class Module(dict):
if strict:
new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters()))
curr_weights = tree_flatten(self.parameters(), destination={})
if extras := (new_weights.keys() - curr_weights.keys()):
num_extra = len(extras)
extras = ",\n".join(sorted(extras))
@@ -212,7 +212,7 @@ class Module(dict):
- ``.npz`` will use :func:`mx.savez`
- ``.safetensors`` will use :func:`mx.save_safetensors`
"""
params_dict = dict(tree_flatten(self.parameters()))
params_dict = tree_flatten(self.parameters(), destination={})
if file.endswith(".npz"):
mx.savez(file, **params_dict)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
from collections import defaultdict
from itertools import zip_longest
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
def tree_map(
@@ -114,8 +114,11 @@ def tree_map_with_path(
def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
tree: Any,
prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
"""Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
@@ -128,9 +131,12 @@ def tree_flatten(
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], ".hello"))
print(tree_flatten([[[0]]], prefix=".hello"))
# [("hello.0.0.0", 0)]
tree_flatten({"a": {"b": 1}}, destination={})
{"a.b": 1}
.. note::
Dictionaries should have keys that are valid Python identifiers.
@@ -140,26 +146,50 @@ def tree_flatten(
always discarded.
is_leaf (callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
destination (list or dict, optional): A list or dictionary to store the
flattened tree. If None an empty list will be used. Default: ``None``.
Returns:
List[Tuple[str, Any]]: The flat representation of the Python tree.
Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
the Python tree.
"""
flat_tree = []
if destination is None:
destination = []
if is_leaf is None or not is_leaf(tree):
if isinstance(tree, (list, tuple)):
for i, t in enumerate(tree):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return flat_tree
if isinstance(tree, dict):
for k, t in tree.items():
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return flat_tree
# Create the function to update the destination. We are taking advantage of
# the fact that list.extend and dict.update have the same API to simplify
# the code a bit.
if isinstance(destination, list):
_add_to_destination = destination.extend
elif isinstance(destination, dict):
_add_to_destination = destination.update
else:
raise ValueError("Destination should be either a list or a dictionary or None")
return [(prefix[1:], tree)]
# Leaf identified by is_leaf so add it and return
if is_leaf is not None and is_leaf(tree):
_add_to_destination([(prefix[1:], tree)])
return destination
# List or tuple so recursively add each subtree
if isinstance(tree, (list, tuple)):
for i, item in enumerate(tree):
tree_flatten(item, f"{prefix}.{i}", is_leaf, destination)
return destination
# Dictionary so recursively add each subtree
if isinstance(tree, dict):
for key, value in tree.items():
tree_flatten(value, f"{prefix}.{key}", is_leaf, destination)
return destination
# Leaf so add it and return
_add_to_destination([(prefix[1:], tree)])
return destination
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.
.. code-block:: python
@@ -170,31 +200,34 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
print(d)
# {"hello": {"world": 42}}
d = tree_unflatten({"hello.world": 42})
print(d)
# {"hello": {"world": 42}}
Args:
tree (list[tuple[str, Any]]): The flat representation of a Python tree.
tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A Python tree.
"""
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
items = tree.items() if isinstance(tree, dict) else tree
try:
int(tree[0][0].split(".", maxsplit=1)[0])
is_list = True
except ValueError:
is_list = False
# Special case when we have just one element in the tree ie not a tree
if len(items) == 1:
key, value = next(iter(items))
if key == "":
return value
# collect children
children = defaultdict(list)
for key, value in tree:
for key, value in items:
current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0]
children[current_idx].append((next_idx, value))
# recursively map them to the original container
if is_list:
# Assume they are a list and fail to dict if the keys are not all integers
try:
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
@@ -202,7 +235,7 @@ def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
l.extend([{} for _ in range(i - len(l))])
l.append(tree_unflatten(children[k]))
return l
else:
except ValueError:
return {k: tree_unflatten(v) for k, v in children.items()}

Some files were not shown because too many files have changed in this diff Show More