mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
8 Commits
simple-gem
...
sdpav-back
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22d0bf273 | ||
|
|
99d8de8445 | ||
|
|
c66b76a8c8 | ||
|
|
f81edd184f | ||
|
|
7f8ba2a003 | ||
|
|
c28249b81a | ||
|
|
e74bcdc5e3 | ||
|
|
d8ed6c1aa3 |
@@ -1,5 +1,4 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
||||
sphinx-copybutton
|
||||
mlx
|
||||
|
||||
@@ -18,7 +18,6 @@ release = version
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
extensions = [
|
||||
"sphinx_copybutton",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
|
||||
@@ -128,7 +128,6 @@ relying on a copy from ``ensure_row_contiguous``:
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
@@ -139,6 +138,7 @@ relying on a copy from ``ensure_row_contiguous``:
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
|
||||
@@ -70,7 +70,6 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/cuda
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
CUDA
|
||||
=====
|
||||
|
||||
.. currentmodule:: mlx.core.cuda
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
@@ -13,4 +13,3 @@ Fast
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
metal_kernel
|
||||
cuda_kernel
|
||||
|
||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state, destination={})
|
||||
mx.save_safetensors("optimizer.safetensors", state)
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
|
||||
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z)
|
||||
return mx.exp(z), state
|
||||
|
||||
fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints [array(3, dtype=float32)]
|
||||
|
||||
@@ -7,17 +7,17 @@ Exporting Functions
|
||||
|
||||
MLX has an API to export and import functions to and from a file. This lets you
|
||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||
front-end (e.g. C++).
|
||||
front-end (e.g. C++).
|
||||
|
||||
This guide walks through the basics of the MLX export API with some examples.
|
||||
To see the full list of functions check-out the :ref:`API documentation
|
||||
<export>`.
|
||||
|
||||
Basics of Exporting
|
||||
Basics of Exporting
|
||||
-------------------
|
||||
|
||||
Let's start with a simple example:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
|
||||
|
||||
# Both arguments to fun are positional
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
|
||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||
they are evaluated. The computation graph that gets exported will include
|
||||
the computation that produces enclosed inputs.
|
||||
|
||||
|
||||
If the above example was missing ``mx.eval(model.parameters()``, the
|
||||
exported function would include the random initialization of the
|
||||
:obj:`mlx.nn.Module` parameters.
|
||||
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
||||
# Set the model's parameters to the input parameters
|
||||
model.update(tree_unflatten(list(params.items())))
|
||||
return model(x)
|
||||
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||
|
||||
|
||||
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
|
||||
# Also ok
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
With ``shapeless=False`` (which is the default), the second call to
|
||||
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
||||
def fun(x, y=None):
|
||||
constant = mx.array(3.0)
|
||||
if y is not None:
|
||||
x += y
|
||||
x += y
|
||||
return x + constant
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
||||
print(out)
|
||||
|
||||
In the above example the function constant data, (i.e. ``constant``), is only
|
||||
saved once.
|
||||
saved once.
|
||||
|
||||
Transformations with Imported Functions
|
||||
---------------------------------------
|
||||
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
|
||||
# Prints: array(1, dtype=float32)
|
||||
print(dfdx(x))
|
||||
|
||||
# Compile the imported function
|
||||
# Compile the imported function
|
||||
mx.compile(imported_fun)
|
||||
# Prints: array(0, dtype=float32)
|
||||
print(compiled_fun(x)[0])
|
||||
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
|
||||
// Prints: array(2, dtype=float32)
|
||||
std::cout << outputs[0] << std::endl;
|
||||
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||
|
||||
|
||||
@@ -228,4 +228,31 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -196,6 +196,9 @@ void shared_buffer_reshape(
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
// Like the swapaxes op but safe to call in eval_gpu.
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
template <typename T>
|
||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
|
||||
@@ -157,12 +157,10 @@ inline void build_kernel(
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name
|
||||
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(i)) {
|
||||
@@ -177,8 +175,8 @@ inline void build_kernel(
|
||||
<< "];" << std::endl;
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
os << " const int64_t* " << xname << "_strides = strides["
|
||||
<< strides_index++ << "];" << std::endl;
|
||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,8 +186,10 @@ inline void build_kernel(
|
||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
// Add output size
|
||||
if (contiguous) {
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||
} else {
|
||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
|
||||
@@ -288,8 +288,17 @@ void Compiled::eval_cpu(
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Force allocating shape/strides on heap so we can take their data() first
|
||||
// and then std::move them.
|
||||
// TODO: Refactor code to avoid heap allocation.
|
||||
shape.grow();
|
||||
for (auto& s : strides) {
|
||||
s.grow();
|
||||
}
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
@@ -297,6 +306,9 @@ void Compiled::eval_cpu(
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
@@ -331,20 +343,16 @@ void Compiled::eval_cpu(
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
if (contiguous) {
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable {
|
||||
SmallVector<int64_t*> strides_ptrs;
|
||||
for (auto& s : strides) {
|
||||
strides_ptrs.push_back(s.data());
|
||||
}
|
||||
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||
});
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
@@ -491,27 +491,19 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
|
||||
@@ -81,7 +81,9 @@ void svd_impl(
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
auto jobz = (u_ptr) ? "A" : "N";
|
||||
auto job_u = (u_ptr) ? "V" : "N";
|
||||
auto job_vt = (u_ptr) ? "V" : "N";
|
||||
static constexpr auto range = "A";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
@@ -89,20 +91,30 @@ void svd_impl(
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const T ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
@@ -124,13 +136,20 @@ void svd_impl(
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in_ptr + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s_ptr + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||
@@ -148,6 +167,13 @@ void svd_impl(
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(in);
|
||||
|
||||
@@ -8,6 +8,7 @@ target_sources(
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
@@ -16,18 +17,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
@@ -50,20 +45,18 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
|
||||
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
|
||||
else()
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
@@ -90,9 +83,6 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# Keep ptx around for inspection
|
||||
target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
|
||||
|
||||
# Enable calling host constexpr functions from device. This is needed because
|
||||
# the constexpr version of isnan is host only.
|
||||
target_compile_options(
|
||||
@@ -158,7 +148,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
FetchContent_Declare(
|
||||
cudnn
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||
GIT_TAG v1.14.0
|
||||
GIT_TAG v1.12.1
|
||||
GIT_SHALLOW TRUE
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||
@@ -178,12 +168,3 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||
|
||||
# Fetch and make available cutlass
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
|
||||
GIT_TAG v4.1.0)
|
||||
FetchContent_Populate(cutlass)
|
||||
target_include_directories(
|
||||
mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)
|
||||
|
||||
@@ -99,89 +99,39 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename Op,
|
||||
typename In,
|
||||
typename Out,
|
||||
typename IdxT,
|
||||
int NDIM,
|
||||
int N_READS>
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto a_stride_x = a_strides[NDIM - 1];
|
||||
auto b_stride_x = b_strides[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto a_stride_x = a_strides[ndim - 1];
|
||||
auto b_stride_x = b_strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
ndim);
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
@@ -259,61 +209,39 @@ void binary_op_gpu_inplace(
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant(),
|
||||
1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant(),
|
||||
4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
rest,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::binary_g<Op, InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
rest,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
@@ -376,4 +304,54 @@ void binary_op_gpu(
|
||||
binary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,21 +0,0 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Add)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(ArcTan2)
|
||||
} // namespace mlx::core
|
||||
@@ -1,27 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Divide)
|
||||
} // namespace mlx::core
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Greater)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(GreaterEqual)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Less)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LessEqual)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LogAddExp)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LogicalAnd)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(LogicalOr)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Maximum)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Minimum)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Multiply)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(NotEqual)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Power)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Remainder)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/binary/binary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
BINARY_GPU(Subtract)
|
||||
} // namespace mlx::core
|
||||
@@ -127,99 +127,45 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename Op,
|
||||
typename In,
|
||||
typename Out,
|
||||
typename IdxT,
|
||||
int NDIM,
|
||||
int N_READS>
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_two_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto a_stride_x = a_strides[NDIM - 1];
|
||||
auto b_stride_x = b_strides[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec_a;
|
||||
AlignedVector<Out, N_READS> out_vec_b;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
auto out = Op{}(a_vec[i], b_vec[i]);
|
||||
out_vec_a[i] = out[0];
|
||||
out_vec_b[i] = out[1];
|
||||
}
|
||||
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
|
||||
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_two_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto a_stride_x = a_strides[ndim - 1];
|
||||
auto b_stride_x = b_strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
ndim);
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec_a;
|
||||
AlignedVector<Out, N_READS> out_vec_b;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
auto out = Op{}(a_vec[i], b_vec[i]);
|
||||
out_vec_a[i] = out[0];
|
||||
out_vec_b[i] = out[1];
|
||||
}
|
||||
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
|
||||
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
@@ -279,64 +225,42 @@ void binary_two_op_gpu_inplace(
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out_a.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel = cu::binary_two_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant(),
|
||||
1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::binary_two_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant(),
|
||||
4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(out_a, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::binary_two_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
rest,
|
||||
out_a.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(out_a, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::binary_two_g<Op, InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
rest,
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
|
||||
@@ -267,8 +267,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
false, std::move(builder.os), std::move(kernel_names));
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/conv/conv.h"
|
||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// cudnn_frontend.h redefines this macro.
|
||||
#undef CHECK_CUDA_ERROR
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
@@ -15,6 +21,9 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Not all engines support it so can not use this API now.
|
||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||
|
||||
// Alias for better readability.
|
||||
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
|
||||
#define CONV_BACKWARD_INPUT \
|
||||
@@ -22,9 +31,6 @@ namespace {
|
||||
#define CONV_BACKWARD_WEIGHT \
|
||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||
|
||||
// Custom placeholder representing fallback kernel.
|
||||
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
@@ -44,13 +50,203 @@ struct ConvCacheKey {
|
||||
auto& conv_cache() {
|
||||
static LRUBytesKeyCache<
|
||||
ConvCacheKey,
|
||||
std::pair<
|
||||
cudnnBackendDescriptorType_t,
|
||||
std::optional<cudnn_frontend::ExecutionPlan>>>
|
||||
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
|
||||
cache(/* capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
auto nhwc_to_nchw(const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
auto strides = convert_vector<int64_t>(x.strides());
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
}
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
|
||||
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = false) {
|
||||
cudnn_frontend::GeneratorSource source;
|
||||
if (use_fallback) {
|
||||
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
};
|
||||
} else {
|
||||
source = [](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
};
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||
auto configs = generator.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
bool execute_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||
|
||||
int64_t uids[3] = {'x', 'w', 'y'};
|
||||
void* data_ptrs[3] = {
|
||||
x.data<void>(),
|
||||
w.data<void>(),
|
||||
y.data<void>(),
|
||||
};
|
||||
|
||||
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setDataPointers(3, data_ptrs)
|
||||
.setUids(3, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||
cudaGraph_t graph;
|
||||
cudaGraphCreate(&graph, 0);
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
if (cudnnBackendPopulateCudaGraph(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
#else
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool try_engines(
|
||||
cu::CommandEncoder& encoder,
|
||||
const ConvCacheKey& cache_key,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const std::string& op_graph_tag,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
for (auto& config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
if (execute_plan(encoder, plan, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(plan)));
|
||||
return true;
|
||||
}
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
auto get_conv_op_settings(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& x,
|
||||
@@ -95,7 +291,7 @@ auto get_conv_op_settings(
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
|
||||
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
@@ -121,9 +317,9 @@ std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
|
||||
.build();
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_cudnn_tensor_nchw('x', x))
|
||||
.setwDesc(build_cudnn_tensor_nchw('w', w))
|
||||
.setyDesc(build_cudnn_tensor_nchw('y', y))
|
||||
.setxDesc(build_tensor('x', x))
|
||||
.setwDesc(build_tensor('w', w))
|
||||
.setyDesc(build_tensor('y', y))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
|
||||
@@ -140,42 +336,6 @@ std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
|
||||
array group_transpose(
|
||||
const array& x,
|
||||
int groups,
|
||||
int group_dim,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Stream s) {
|
||||
if (groups == 1) {
|
||||
return swapaxes_in_eval(x, axis1, axis2);
|
||||
}
|
||||
int ndim = x.ndim();
|
||||
if (group_dim < 0) {
|
||||
group_dim += ndim;
|
||||
}
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
if (group_dim <= axis1) {
|
||||
axis1 += 1;
|
||||
}
|
||||
if (group_dim <= axis2) {
|
||||
axis2 += 1;
|
||||
}
|
||||
auto shape = x.shape();
|
||||
shape.insert(shape.begin() + group_dim, groups);
|
||||
shape[group_dim + 1] = shape[group_dim + 1] / groups;
|
||||
array x_trans = reshape_in_eval(x, std::move(shape), s);
|
||||
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
|
||||
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
|
||||
return x_trans;
|
||||
}
|
||||
|
||||
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||
// eval_gpu, with cost of possible redundant copies.
|
||||
@@ -185,14 +345,13 @@ std::tuple<array, array, array> prepare_args(
|
||||
array in,
|
||||
array wt,
|
||||
array out,
|
||||
int groups,
|
||||
Stream s) {
|
||||
// Transpose the args depending on the backend type.
|
||||
// TODO: Handle groups.
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
wt = group_transpose(wt, groups, 0, 0, -1, s);
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
in = group_transpose(in, groups, -1, 0, -1, s);
|
||||
in = swapaxes_in_eval(in, 0, -1);
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
// Create a contiguous array that shares the data with |out|, but with dim
|
||||
// C_in and C_out swapped.
|
||||
@@ -285,12 +444,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(dtype),
|
||||
vector_key(in.shape()),
|
||||
vector_key(wt.shape()),
|
||||
vector_key(kernel_strides_),
|
||||
vector_key(padding_lo_),
|
||||
vector_key(padding_hi_),
|
||||
vector_key(kernel_dilation_),
|
||||
fixed_vector(in.shape()),
|
||||
fixed_vector(wt.shape()),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(padding_lo_),
|
||||
fixed_vector(padding_hi_),
|
||||
fixed_vector(kernel_dilation_),
|
||||
groups_,
|
||||
flip_,
|
||||
get_alignment(in),
|
||||
@@ -298,29 +457,11 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
auto& [backend_type, plan] = it->second;
|
||||
if (plan) {
|
||||
// Run cached plan.
|
||||
std::tie(in, wt, out) =
|
||||
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
}
|
||||
} else {
|
||||
// Run fallback kernel.
|
||||
gemm_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
s);
|
||||
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!execute_plan(encoder, plan, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -349,7 +490,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||
for (auto try_backend : try_backends) {
|
||||
auto [in_copy, wt_copy, out_copy] =
|
||||
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
|
||||
prepare_args(encoder, try_backend, in, wt, out, s);
|
||||
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||
try_backend,
|
||||
@@ -361,7 +502,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_);
|
||||
op_graph = build_conv_op_graph(
|
||||
op_graph = build_op_graph(
|
||||
encoder,
|
||||
try_backend,
|
||||
dtype,
|
||||
@@ -380,39 +521,26 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (op_graph) {
|
||||
// Setup inputs and outputs.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
// Find a plan for the graph and execute it.
|
||||
auto plan = find_cudnn_plan_from_op_graph(
|
||||
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
||||
if (!plan) {
|
||||
throw std::runtime_error("[conv] Unable to find an execution plan.");
|
||||
}
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
||||
return;
|
||||
}
|
||||
if (!op_graph) {
|
||||
throw std::runtime_error("[conv] Can not build op graph.");
|
||||
}
|
||||
|
||||
// Use fallback kernel for settings not supported by cuDNN.
|
||||
gemm_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
s);
|
||||
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
|
||||
// Get ready to execute the graph.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
auto tag = op_graph->getTag();
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("[conv] Unable to find a working engine.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <int NDIM>
|
||||
struct ConvParams {
|
||||
int N; // Batch size
|
||||
int C; // In channels
|
||||
int O; // Out channels
|
||||
int strides[NDIM];
|
||||
int padding[NDIM];
|
||||
int kernel_dilation[NDIM];
|
||||
int input_dilation[NDIM];
|
||||
int groups;
|
||||
bool flip;
|
||||
int in_spatial_dims[NDIM];
|
||||
int wt_spatial_dims[NDIM];
|
||||
int out_spatial_dims[NDIM];
|
||||
int64_t in_strides[NDIM + 2];
|
||||
|
||||
ConvParams(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
const array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip)
|
||||
: N(in.shape(0)),
|
||||
C(in.shape(-1)),
|
||||
O(wt.shape(0)),
|
||||
groups(groups),
|
||||
flip(flip) {
|
||||
std::copy_n(strides.begin(), NDIM, this->strides);
|
||||
std::copy_n(padding.begin(), NDIM, this->padding);
|
||||
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
|
||||
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
|
||||
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
|
||||
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
|
||||
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
|
||||
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
|
||||
}
|
||||
};
|
||||
|
||||
void gemm_grouped_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
Stream s);
|
||||
|
||||
void gemm_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
bool flip,
|
||||
Stream s);
|
||||
|
||||
inline void gemm_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
array in,
|
||||
array wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
Stream s) {
|
||||
if (!in.flags().row_contiguous) {
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
encoder.add_temporary(wt);
|
||||
}
|
||||
|
||||
if (groups == 1) {
|
||||
gemm_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip,
|
||||
s);
|
||||
} else {
|
||||
gemm_grouped_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
groups,
|
||||
flip,
|
||||
s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,217 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/conv/conv.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int NDIM>
|
||||
__global__ void naive_unfold_nd(
|
||||
const T* in,
|
||||
T* out,
|
||||
int filter_size,
|
||||
int out_pixels,
|
||||
const __grid_constant__ ConvParams<NDIM> params) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto tid = block.group_index();
|
||||
auto lid = block.thread_index();
|
||||
|
||||
int index_batch = tid.z / out_pixels; // [0, N)
|
||||
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||
int index_wt_spatial =
|
||||
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||
|
||||
if (index_wt_spatial >= filter_size / params.C) {
|
||||
return;
|
||||
}
|
||||
|
||||
in += tid.y; // [0, C)
|
||||
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
|
||||
|
||||
bool valid = index_batch < params.N;
|
||||
|
||||
// Get the coordinates in input.
|
||||
int index_in[NDIM] = {};
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||
|
||||
if (params.flip) {
|
||||
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||
}
|
||||
|
||||
int index = index_out * params.strides[i] - params.padding[i] +
|
||||
index_wt * params.kernel_dilation[i];
|
||||
int index_max =
|
||||
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||
|
||||
valid &= (index >= 0) && (index < index_max) &&
|
||||
(index % params.input_dilation[i] == 0);
|
||||
|
||||
index_in[i] = index / params.input_dilation[i];
|
||||
|
||||
index_out_spatial /= params.out_spatial_dims[i];
|
||||
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
if (valid) {
|
||||
int in_offset = index_batch * params.in_strides[0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||
}
|
||||
*out = in[in_offset];
|
||||
} else {
|
||||
*out = T{0};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <int NDIM>
|
||||
array unfold_inputs_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
int mat_M,
|
||||
int mat_K,
|
||||
int mat_N,
|
||||
ConvParams<NDIM>& params) {
|
||||
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
|
||||
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||
encoder.add_temporary(unfolded);
|
||||
|
||||
int filter_size = params.C;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
filter_size *= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
int out_pixels = 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
out_pixels *= params.out_spatial_dims[i];
|
||||
}
|
||||
|
||||
int wt_spatial_size = mat_K / params.C;
|
||||
dim3 block_dims;
|
||||
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||
dim3 num_blocks;
|
||||
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||
num_blocks.y = params.C;
|
||||
num_blocks.z = mat_M;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(unfolded);
|
||||
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
encoder.add_kernel_node(
|
||||
cu::naive_unfold_nd<DataType, NDIM>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
unfolded.data<DataType>(),
|
||||
filter_size,
|
||||
out_pixels,
|
||||
params);
|
||||
});
|
||||
|
||||
return unfolded;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
void gemm_conv_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
ConvParams<NDIM>& params,
|
||||
Stream s) {
|
||||
// Get gemm shapes.
|
||||
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
|
||||
int mat_N = params.O; // O
|
||||
|
||||
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||
array in_unfolded =
|
||||
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
|
||||
|
||||
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
|
||||
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
|
||||
wt_reshaped.copy_shared_buffer(
|
||||
wt,
|
||||
{1, mat_K},
|
||||
{false, false, /* col_contiguous */ true},
|
||||
wt.data_size());
|
||||
|
||||
// Single batch.
|
||||
Shape batch_shape{1};
|
||||
Strides a_batch_strides{0};
|
||||
Strides b_batch_strides{0};
|
||||
|
||||
// Run matmul.
|
||||
CublasGemm gemm(
|
||||
encoder.device(),
|
||||
in.dtype(),
|
||||
false, // a_transposed
|
||||
mat_M, // a_rows
|
||||
mat_K, // a_cols
|
||||
mat_K, // lda
|
||||
true, // b_transposed
|
||||
mat_K, // b_rows
|
||||
mat_N, // b_cols
|
||||
mat_K, // ldb
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.run(
|
||||
encoder,
|
||||
out,
|
||||
in_unfolded,
|
||||
wt_reshaped,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides);
|
||||
}
|
||||
|
||||
void gemm_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
bool flip,
|
||||
Stream s) {
|
||||
int conv_ndim = in.ndim() - 2;
|
||||
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||
}
|
||||
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||
ConvParams<ndim_constant()> params(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
1, // groups
|
||||
flip);
|
||||
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,231 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/conv/conv.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int NDIM>
|
||||
__global__ void naive_grouped_unfold_transpose_nd(
|
||||
const T* in,
|
||||
T* out,
|
||||
int filter_size,
|
||||
int out_pixels,
|
||||
const __grid_constant__ ConvParams<NDIM> params) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto tid = block.group_index();
|
||||
auto lid = block.thread_index();
|
||||
|
||||
int index_batch = tid.z / out_pixels; // [0, N)
|
||||
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||
int index_wt_spatial =
|
||||
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||
|
||||
if (index_wt_spatial >= filter_size / params.C) {
|
||||
return;
|
||||
}
|
||||
|
||||
in += tid.y; // [0, C)
|
||||
out += tid.z * filter_size + tid.y * (filter_size / params.C);
|
||||
|
||||
bool valid = index_batch < params.N;
|
||||
|
||||
// Get the coordinates in input.
|
||||
int index_in[NDIM] = {};
|
||||
int wt_stride = 1;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||
out += index_wt * wt_stride;
|
||||
|
||||
if (params.flip) {
|
||||
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||
}
|
||||
|
||||
int index = index_out * params.strides[i] - params.padding[i] +
|
||||
index_wt * params.kernel_dilation[i];
|
||||
int index_max =
|
||||
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||
|
||||
valid &= (index >= 0) && (index < index_max) &&
|
||||
(index % params.input_dilation[i] == 0);
|
||||
|
||||
index_in[i] = index / params.input_dilation[i];
|
||||
|
||||
index_out_spatial /= params.out_spatial_dims[i];
|
||||
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||
wt_stride *= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
if (valid) {
|
||||
int in_offset = index_batch * params.in_strides[0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||
}
|
||||
*out = in[in_offset];
|
||||
} else {
|
||||
*out = T{0};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <int NDIM>
|
||||
array grouped_unfold_transpose_inputs_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
int mat_M,
|
||||
int mat_K,
|
||||
int mat_N,
|
||||
ConvParams<NDIM>& params) {
|
||||
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
|
||||
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||
encoder.add_temporary(unfolded);
|
||||
|
||||
int filter_size = params.C;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
filter_size *= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
int out_pixels = 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
out_pixels *= params.out_spatial_dims[i];
|
||||
}
|
||||
|
||||
int wt_spatial_size = (mat_K * params.groups) / params.C;
|
||||
dim3 block_dims;
|
||||
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||
dim3 num_blocks;
|
||||
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||
num_blocks.y = params.C;
|
||||
num_blocks.z = mat_M;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(unfolded);
|
||||
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
encoder.add_kernel_node(
|
||||
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
unfolded.data<DataType>(),
|
||||
filter_size,
|
||||
out_pixels,
|
||||
params);
|
||||
});
|
||||
|
||||
return unfolded;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
void gemm_grouped_conv_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
ConvParams<NDIM>& params,
|
||||
Stream s) {
|
||||
// Get gemm shapes.
|
||||
int C_per_group = params.C / params.groups;
|
||||
int O_per_group = params.O / params.groups;
|
||||
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
|
||||
int mat_N = O_per_group; // O_per_group
|
||||
|
||||
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
|
||||
encoder, in, mat_M, mat_K, mat_N, params);
|
||||
|
||||
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
|
||||
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
|
||||
array wt_view(
|
||||
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
|
||||
|
||||
// Batch with size of groups.
|
||||
Shape batch_shape{params.groups};
|
||||
Strides a_batch_strides{mat_K};
|
||||
Strides b_batch_strides{mat_N * mat_K};
|
||||
|
||||
// Run matmul.
|
||||
CublasGemm gemm(
|
||||
encoder.device(),
|
||||
in.dtype(),
|
||||
false, // a_transposed
|
||||
mat_M, // a_rows
|
||||
mat_K, // a_cols
|
||||
mat_K * params.groups, // lda
|
||||
true, // b_transposed
|
||||
mat_K, // b_rows
|
||||
mat_N, // b_cols
|
||||
mat_K, // ldb
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.set_out(
|
||||
out.dtype(),
|
||||
false, // out_transposed
|
||||
mat_M, // out_rows
|
||||
mat_N, // out_cols
|
||||
mat_N * params.groups, // out_ld
|
||||
params.groups, // batch_count
|
||||
mat_N); // batch_stride
|
||||
gemm.run(
|
||||
encoder,
|
||||
out,
|
||||
in_unfolded,
|
||||
wt_reshaped,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides);
|
||||
}
|
||||
|
||||
void gemm_grouped_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
Stream s) {
|
||||
int conv_ndim = in.ndim() - 2;
|
||||
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||
}
|
||||
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||
ConvParams<ndim_constant()> params(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
groups,
|
||||
flip);
|
||||
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -10,80 +10,37 @@ namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto in_stride_x = strides_in[NDIM - 1];
|
||||
auto out_stride_x = strides_out[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
strides_in.data(),
|
||||
strides_out.data());
|
||||
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto in_stride_x = strides_in[ndim - 1];
|
||||
auto out_stride_x = strides_out[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [idx_in, idx_out] = elem_to_loc(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
strides_in.data(),
|
||||
strides_out.data(),
|
||||
ndim);
|
||||
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
@@ -112,52 +69,33 @@ void copy_general(
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = data_size / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
auto kernel =
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
rest,
|
||||
data_size,
|
||||
const_param<ndim_constant()>(shape),
|
||||
const_param<ndim_constant()>(strides_in),
|
||||
const_param<ndim_constant()>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::copy_gg<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
rest,
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
|
||||
@@ -10,67 +10,33 @@ namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_g_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto stride_x = strides[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto idx =
|
||||
elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto stride_x = strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto idx =
|
||||
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
@@ -95,49 +61,30 @@ void copy_general_input(
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
rest,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::copy_g<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
rest,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Create a cudnn tensor descriptor.
|
||||
template <typename Vec>
|
||||
inline cudnn_frontend::Tensor build_cudnn_tensor(
|
||||
int64_t id,
|
||||
const array& x,
|
||||
const Vec& shape,
|
||||
const Vec& strides) {
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
|
||||
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
|
||||
// whether a tensor is contiguous is determined with:
|
||||
// shape[dim] == shape[dim + 1] * strides[dim + 1]
|
||||
// So a contiguous array with singleton dims in MLX may be mistakenly treated
|
||||
// as strided in cuDNN, and we work around it by normalizing the strides.
|
||||
Strides normalized_strides(const array& x) {
|
||||
if (!x.flags().row_contiguous || x.ndim() < 2) {
|
||||
return x.strides();
|
||||
}
|
||||
Strides strides = x.strides();
|
||||
for (int i = x.ndim() - 2; i >= 0; --i) {
|
||||
if (x.shape(i) == 1) {
|
||||
strides[i] = x.shape(i + 1) * strides[i + 1];
|
||||
}
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
// Return the shape and strides after transposing from NHWC to NCHW.
|
||||
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
||||
assert(shape.size() >= 3);
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline auto nhwc_to_nchw(const array& x) {
|
||||
return nhwc_to_nchw(
|
||||
convert_vector<int64_t>(x.shape()), normalized_strides(x));
|
||||
}
|
||||
|
||||
// Return available engines for a |op_graph|.
|
||||
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = true) {
|
||||
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
|
||||
sources.push_back([](auto& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
});
|
||||
if (use_fallback) {
|
||||
sources.push_back([&backend_type](auto& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
});
|
||||
}
|
||||
|
||||
auto configs =
|
||||
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
|
||||
.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
// Take |engine_configs| and |op_graph| and find a working execution plans
|
||||
// from them.
|
||||
std::optional<cudnn_frontend::ExecutionPlan>
|
||||
find_cudnn_plan_from_engine_configs(
|
||||
cudnnHandle_t handle,
|
||||
const cudnn_frontend::EngineConfigList& engine_configs,
|
||||
const cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto op_graph_tag = op_graph.getTag();
|
||||
for (const auto& config : engine_configs) {
|
||||
try {
|
||||
return cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(handle)
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Prepare workspace and args to execute plan.
|
||||
template <typename F>
|
||||
bool prepare_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs,
|
||||
F&& execute) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(
|
||||
workspace_size > 0 ? allocator::malloc(workspace_size)
|
||||
: allocator::Buffer(nullptr),
|
||||
{workspace_size},
|
||||
uint8);
|
||||
|
||||
auto args = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setDataPointers(num_args, data_ptrs)
|
||||
.setUids(num_args, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
|
||||
if (x.ndim() == 0) {
|
||||
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
|
||||
}
|
||||
if (x.ndim() == 1) {
|
||||
int64_t s = x.shape(0);
|
||||
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
|
||||
SmallVector<int64_t, 4> strides = {s, 1, s, s};
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
if (x.ndim() == 2) {
|
||||
int64_t s =
|
||||
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
|
||||
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
|
||||
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
if (x.ndim() == 3 || x.ndim() == 4) {
|
||||
return build_cudnn_tensor_nchw(id, x);
|
||||
}
|
||||
throw std::runtime_error(
|
||||
fmt::format("Unsupported array with {} dims.", x.ndim()));
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
|
||||
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(scalar_dims.size(), scalar_dims.data())
|
||||
.setStrides(scalar_dims.size(), scalar_dims.data())
|
||||
.setId(id)
|
||||
.setAlignment(16)
|
||||
.setDataType(dtype_to_cudnn_type(dtype))
|
||||
.setByValue(true)
|
||||
.build();
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||
cudnnHandle_t handle,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
|
||||
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
|
||||
}
|
||||
|
||||
bool encode_cudnn_plan_with_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs) {
|
||||
return prepare_cudnn_plan(
|
||||
encoder,
|
||||
plan,
|
||||
num_args,
|
||||
uids,
|
||||
data_ptrs,
|
||||
[&](auto handle, auto plan, auto args) {
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
bool encode_cudnn_plan_with_graph_api(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs) {
|
||||
return prepare_cudnn_plan(
|
||||
encoder,
|
||||
plan,
|
||||
num_args,
|
||||
uids,
|
||||
data_ptrs,
|
||||
[&](auto handle, auto plan, auto args) {
|
||||
if (!graph) {
|
||||
graph = CudaGraph(encoder.device());
|
||||
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,164 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class CommandEncoder;
|
||||
}
|
||||
|
||||
// Return pointer alignment of |x|'s data.
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
}
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
|
||||
// Convert the type of elements in |vec| to |T|.
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
|
||||
//
|
||||
// There are 2 differences from the const_param util from kernel_utils.cuh:
|
||||
// 1. The rest of array is filled with 0.
|
||||
// 2. This util can be used in .cpp files.
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helpers used by get_data_ptrs to get pointers.
|
||||
inline void* get_data_ptr(const array& arr) {
|
||||
return const_cast<void*>(arr.data<void>());
|
||||
}
|
||||
|
||||
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
|
||||
inline void* get_data_ptr(T& scalar) {
|
||||
return &scalar;
|
||||
}
|
||||
|
||||
// Return an array filled with data pointers of args.
|
||||
template <typename... Args>
|
||||
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
|
||||
return {get_data_ptr(args)...};
|
||||
}
|
||||
|
||||
// Map dtype to cudnn data type.
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
// Create a tensor descriptor from |x|.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
|
||||
|
||||
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
|
||||
|
||||
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
|
||||
// from NHWC to NCHW.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
|
||||
|
||||
// Create a 4D scalar tensor descriptor, which is passed by value.
|
||||
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
|
||||
|
||||
// Find a working plan for |op_graph|.
|
||||
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||
cudnnHandle_t handle,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph);
|
||||
|
||||
// Encode the plan to command buffer by capturing.
|
||||
bool encode_cudnn_plan_with_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs);
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
// Encode the plan to command buffer by using native graph api of cudnn. If the
|
||||
// |graph| is empty it will be populated, otherwise it will be updated.
|
||||
bool encode_cudnn_plan_with_graph_api(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs);
|
||||
#endif
|
||||
|
||||
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
|
||||
template <typename... Args>
|
||||
bool encode_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
std::initializer_list<int64_t> uids,
|
||||
Args&... args) {
|
||||
assert(uids.size() == sizeof...(args));
|
||||
auto data_ptrs = get_data_ptrs(args...);
|
||||
return encode_cudnn_plan_with_capturing(
|
||||
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
|
||||
}
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
template <typename... Args>
|
||||
bool encode_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
std::initializer_list<int64_t> uids,
|
||||
Args&... args) {
|
||||
assert(uids.size() == sizeof...(args));
|
||||
auto data_ptrs = get_data_ptrs(args...);
|
||||
return encode_cudnn_plan_with_graph_api(
|
||||
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,379 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* default_header = R"(
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
|
||||
)";
|
||||
|
||||
std::string template_arguments_hash(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
if (template_args.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string hash;
|
||||
hash.reserve(512);
|
||||
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
hash += fmt::format("_{}", std::get<int>(arg));
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
hash += (std::get<bool>(arg)) ? "_t" : "_f";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
hash += "_";
|
||||
hash += get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
std::string build_kernel(
|
||||
const std::string& func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 8192);
|
||||
kernel_source += default_header;
|
||||
kernel_source += header;
|
||||
kernel_source +=
|
||||
"namespace mlx::core::cu {\n\n"
|
||||
"namespace cg = cooperative_groups;\n\n";
|
||||
|
||||
kernel_source += "__global__ void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
kernel_source += " const ";
|
||||
kernel_source += dtype_to_cuda_type(arr.dtype());
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += ",\n";
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source += " const __grid_constant__ Shape ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_shape,\n";
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source += " const __grid_constant__ Strides ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_strides,\n";
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source += " const __grid_constant__ int ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_ndim,\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype_to_cuda_type(dtype);
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
if (i < output_names.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Set compile time constants
|
||||
if (!template_args.empty()) {
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
kernel_source +=
|
||||
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
kernel_source += fmt::format(
|
||||
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
|
||||
} else {
|
||||
kernel_source += fmt::format(
|
||||
" using {} = {};\n",
|
||||
name,
|
||||
dtype_to_cuda_type(std::get<Dtype>(arg)));
|
||||
}
|
||||
}
|
||||
kernel_source += "\n";
|
||||
}
|
||||
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
|
||||
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CustomKernelFunction cuda_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header,
|
||||
bool ensure_row_contiguous,
|
||||
int shared_memory) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[custom_kernel] Must specify at least one output.");
|
||||
}
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
|
||||
return [=, shape_infos = std::move(shape_infos)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[custom_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[custom_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[custom_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::string kernel_name =
|
||||
"custom_kernel_" + name + template_arguments_hash(template_args);
|
||||
std::string kernel_source = build_kernel(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos);
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << kernel_name
|
||||
<< "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value,
|
||||
std::vector<ScalarArg>{},
|
||||
false,
|
||||
shared_memory),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> precompiled_cuda_kernel(
|
||||
const std::string& name,
|
||||
const std::string& compiled_source,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<ScalarArg>& scalars,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
int shared_memory,
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
StreamOrDevice s) {
|
||||
std::vector<CustomKernelShapeInfo> shape_infos(
|
||||
inputs.size(), CustomKernelShapeInfo{false, false, false});
|
||||
return array::make_arrays(
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
std::make_shared<CustomKernel>(
|
||||
to_stream(s),
|
||||
name,
|
||||
compiled_source,
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value,
|
||||
scalars,
|
||||
true,
|
||||
shared_memory),
|
||||
inputs);
|
||||
}
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("CustomKernel::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
// Allocate and initialize the output arrays
|
||||
for (auto& out : outputs) {
|
||||
if (init_value_) {
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
// Create the input arrays and copy if needed
|
||||
auto check_input = [&copies, &s, this](const array& x) -> const array {
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (!ensure_row_contiguous_ || no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
std::vector<array> checked_inputs;
|
||||
for (const array& in : inputs) {
|
||||
checked_inputs.push_back(check_input(in));
|
||||
}
|
||||
|
||||
// Compile the custom kernel
|
||||
std::string kernel_name =
|
||||
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
|
||||
cu::JitModule& mod = cu::get_jit_module(
|
||||
s.device,
|
||||
name_,
|
||||
[&]() {
|
||||
return std::make_tuple(
|
||||
is_precompiled_, source_, std::vector{kernel_name});
|
||||
},
|
||||
false);
|
||||
|
||||
// Make the arguments
|
||||
cu::KernelArgs args;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
auto& shape_info = shape_infos_[i];
|
||||
args.append(in);
|
||||
if (shape_info.shape) {
|
||||
args.append_ndim(in.shape());
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
args.append_ndim(in.strides());
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
args.append<int32_t>(in.ndim());
|
||||
}
|
||||
}
|
||||
for (auto& out : outputs) {
|
||||
args.append(out);
|
||||
}
|
||||
for (auto& s : scalar_arguments_) {
|
||||
if (std::holds_alternative<bool>(s)) {
|
||||
args.append(std::get<bool>(s));
|
||||
} else if (std::holds_alternative<int>(s)) {
|
||||
args.append(std::get<int>(s));
|
||||
} else if (std::holds_alternative<float>(s)) {
|
||||
args.append(std::get<float>(s));
|
||||
}
|
||||
}
|
||||
|
||||
// Make the grid
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
|
||||
|
||||
// Call the kernel
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : checked_inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
for (const auto& t : copies) {
|
||||
encoder.add_temporary(t);
|
||||
}
|
||||
auto kernel =
|
||||
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
|
||||
if (smem > 0 && smem > 48000) {
|
||||
cuFuncSetAttribute(
|
||||
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
|
||||
}
|
||||
});
|
||||
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
@@ -91,7 +91,9 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
graph.end_capture(enc.stream());
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
|
||||
if (discard) {
|
||||
return;
|
||||
}
|
||||
@@ -183,10 +185,9 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(Device& d)
|
||||
: device_(d),
|
||||
stream_(d),
|
||||
graph_(d),
|
||||
graph_cache_(cuda_graph_cache_size()) {}
|
||||
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
@@ -310,7 +311,8 @@ void CommandEncoder::commit() {
|
||||
to_nodes_.clear();
|
||||
graph_key_.clear();
|
||||
node_map_.clear();
|
||||
graph_ = CudaGraph(device_);
|
||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
|
||||
@@ -21,7 +21,7 @@ class CommandEncoder {
|
||||
struct CaptureContext {
|
||||
CaptureContext(CommandEncoder& enc);
|
||||
~CaptureContext();
|
||||
CudaGraph graph;
|
||||
cudaGraph_t graph;
|
||||
CommandEncoder& enc;
|
||||
bool discard{false};
|
||||
};
|
||||
@@ -115,7 +115,7 @@ class CommandEncoder {
|
||||
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
CudaGraph graph_;
|
||||
cudaGraph_t graph_;
|
||||
Worker worker_;
|
||||
char node_count_{0};
|
||||
char graph_node_count_{0};
|
||||
|
||||
@@ -146,23 +146,6 @@ inline __device__ void store_vector(
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T, typename SizeT>
|
||||
inline __device__ void store_vector(
|
||||
T* ptr,
|
||||
uint32_t offset,
|
||||
const AlignedVector<T, N>& vec,
|
||||
SizeT size,
|
||||
int64_t stride) {
|
||||
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
|
||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||
to[offset] = vec;
|
||||
} else {
|
||||
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
|
||||
ptr[stride * (offset * N + i)] = vec[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -4,16 +4,16 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
namespace mlx::core::cu {
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
void Matmul::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
@@ -22,7 +22,7 @@ void CublasGemm::run_batched(
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
execute(
|
||||
run_impl(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
@@ -33,16 +33,16 @@ void CublasGemm::run_batched(
|
||||
}
|
||||
}
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
void Matmul::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
encoder.set_input_array(a);
|
||||
@@ -56,7 +56,7 @@ void CublasGemm::run_batched(
|
||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
execute(
|
||||
run_impl(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
@@ -70,4 +70,4 @@ void CublasGemm::run_batched(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core::cu
|
||||
208
mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
Normal file
208
mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
Normal file
@@ -0,0 +1,208 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__global__ void set_mm_device_pointers(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_addmm_device_pointers(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
const __grid_constant__ Strides c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
|
||||
&batch_mode,
|
||||
sizeof(batch_mode)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||
}
|
||||
|
||||
void Matmul::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides) {
|
||||
auto batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
|
||||
{static_cast<int>(batch_count * 3)},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
static_cast<int64_t>(M_) * N_,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto out_pointers = b_pointers + batch_count;
|
||||
run_impl(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void Matmul::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
auto batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(c_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||
{static_cast<int>(batch_count * 4)},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
const_param(c_batch_strides),
|
||||
static_cast<int64_t>(M_) * N_,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto c_pointers = b_pointers + batch_count;
|
||||
auto out_pointers = c_pointers + batch_count;
|
||||
run_impl(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
reinterpret_cast<void*>(c_pointers),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -7,12 +7,10 @@
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct CublasPreference {
|
||||
CublasPreference(cu::Device& device) {
|
||||
CublasPreference(Device& device) {
|
||||
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
|
||||
// for Hopper+:
|
||||
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
|
||||
@@ -35,7 +33,7 @@ struct CublasPreference {
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
};
|
||||
|
||||
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
|
||||
cublasLtMatmulPreference_t cublas_preference(Device& device) {
|
||||
static CublasPreference pref(device);
|
||||
return pref.pref_;
|
||||
}
|
||||
@@ -54,7 +52,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +70,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
|
||||
return CUDA_C_32F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,10 +102,8 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
||||
return desc;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CublasGemm::CublasGemm(
|
||||
cu::Device& device,
|
||||
Matmul::Matmul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -159,8 +155,8 @@ CublasGemm::CublasGemm(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
}
|
||||
|
||||
CublasGemm::CublasGemm(
|
||||
cu::Device& device,
|
||||
Matmul::Matmul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -175,7 +171,7 @@ CublasGemm::CublasGemm(
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride)
|
||||
: CublasGemm(
|
||||
: Matmul(
|
||||
device,
|
||||
dtype,
|
||||
a_transposed,
|
||||
@@ -194,7 +190,7 @@ CublasGemm::CublasGemm(
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
CublasGemm::~CublasGemm() {
|
||||
Matmul::~Matmul() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
|
||||
@@ -202,92 +198,7 @@ CublasGemm::~CublasGemm() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void CublasGemm::set_out(
|
||||
Dtype dtype,
|
||||
bool transposed,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||
out_desc_ = create_matrix_layout(
|
||||
dtype_to_cublas_type(dtype),
|
||||
rows,
|
||||
cols,
|
||||
transposed,
|
||||
ld,
|
||||
batch_count,
|
||||
batch_stride);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
return;
|
||||
}
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
c_batch_strides,
|
||||
alpha,
|
||||
beta);
|
||||
return;
|
||||
}
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c.data<void>(),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
void CublasGemm::execute(
|
||||
void Matmul::run_impl(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
const void* a,
|
||||
@@ -345,4 +256,29 @@ void CublasGemm::execute(
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
void Matmul::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::optional<array>& c /* = std::nullopt */,
|
||||
float alpha /* = 1 */,
|
||||
float beta /* = 0 */) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
if (c) {
|
||||
encoder.set_input_array(*c);
|
||||
}
|
||||
encoder.set_output_array(out);
|
||||
|
||||
run_impl(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
c ? c->data<void>() : nullptr,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -5,13 +5,13 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <optional>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
class CublasGemm {
|
||||
namespace mlx::core::cu {
|
||||
class Matmul {
|
||||
public:
|
||||
CublasGemm(
|
||||
cu::Device& device,
|
||||
Matmul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -25,8 +25,8 @@ class CublasGemm {
|
||||
int64_t a_batch_stride,
|
||||
int64_t b_batch_stride);
|
||||
|
||||
CublasGemm(
|
||||
cu::Device& device,
|
||||
Matmul(
|
||||
Device& device,
|
||||
Dtype dtype,
|
||||
bool a_transposed,
|
||||
uint64_t a_rows,
|
||||
@@ -42,65 +42,41 @@ class CublasGemm {
|
||||
int64_t b_batch_stride,
|
||||
int64_t c_batch_stride);
|
||||
|
||||
~CublasGemm();
|
||||
|
||||
// The output's descriptor is inferred from inputs by default, use this method
|
||||
// for unusual output.
|
||||
void set_out(
|
||||
Dtype dtype,
|
||||
bool transposed,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride);
|
||||
~Matmul();
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
const std::optional<array>& c = std::nullopt,
|
||||
float alpha = 1,
|
||||
float beta = 0);
|
||||
|
||||
void run(
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides);
|
||||
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides,
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta);
|
||||
|
||||
private:
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta);
|
||||
|
||||
void execute(
|
||||
void run_impl(
|
||||
cu::CommandEncoder& encoder,
|
||||
void* out,
|
||||
const void* a,
|
||||
@@ -121,4 +97,4 @@ class CublasGemm {
|
||||
cublasLtMatmulHeuristicResult_t heuristic_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <int NDIM>
|
||||
__global__ void set_mm_device_pointers_nd(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data());
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_mm_device_pointers_g(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
__global__ void set_addmm_device_pointers_nd(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data());
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_addmm_device_pointers_g(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ Shape batch_shape,
|
||||
const __grid_constant__ Strides a_batch_strides,
|
||||
const __grid_constant__ Strides b_batch_strides,
|
||||
const __grid_constant__ Strides c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_ndim,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data(),
|
||||
batch_ndim);
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace {
|
||||
|
||||
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
|
||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
|
||||
&batch_mode,
|
||||
sizeof(batch_mode)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(void*) * 3),
|
||||
{batch_count * 3},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
int block_dims = std::min(batch_count, 256);
|
||||
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||
int64_t batch_stride = M_ * N_;
|
||||
int item_size = out.itemsize();
|
||||
|
||||
int ndim = batch_shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers_nd<ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
const_param<ndim_constant()>(b_batch_strides),
|
||||
batch_stride,
|
||||
batch_count);
|
||||
});
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers_g,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
batch_stride,
|
||||
ndim,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto out_pointers = b_pointers + batch_count;
|
||||
execute(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides,
|
||||
const Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(c_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||
{batch_count * 4},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
int block_dims = std::min(batch_count, 256);
|
||||
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||
int64_t batch_stride = M_ * N_;
|
||||
int item_size = out.itemsize();
|
||||
|
||||
int ndim = batch_shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers_nd<ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
const_param<ndim_constant()>(b_batch_strides),
|
||||
const_param<ndim_constant()>(c_batch_strides),
|
||||
batch_stride,
|
||||
batch_count);
|
||||
});
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers_g,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
const_param(c_batch_strides),
|
||||
batch_stride,
|
||||
ndim,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto c_pointers = b_pointers + batch_count;
|
||||
auto out_pointers = c_pointers + batch_count;
|
||||
execute(
|
||||
encoder,
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
reinterpret_cast<void*>(c_pointers),
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,396 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/arch/arch.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/layout/matrix.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace cute;
|
||||
using bf16 = cute::bfloat16_t;
|
||||
|
||||
template <typename Kernel>
|
||||
void configure_matmul(Kernel kernel, int smem_size) {
|
||||
static bool initialized = false;
|
||||
if (!initialized) {
|
||||
initialized = true;
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool transpose, typename Tiler>
|
||||
constexpr int get_feature_size(Tiler smem) {
|
||||
int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
|
||||
return (feature_size >= 64) ? 64 : feature_size;
|
||||
}
|
||||
|
||||
constexpr int constexpr_log2(int x) {
|
||||
return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
|
||||
}
|
||||
|
||||
template <int feature_size, int itemsize, int copy_bits>
|
||||
constexpr int get_swizzle_bits() {
|
||||
constexpr int swizzle_bits =
|
||||
constexpr_log2(feature_size * itemsize / copy_bits);
|
||||
return (swizzle_bits > 3) ? 3 : swizzle_bits;
|
||||
}
|
||||
|
||||
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
|
||||
constexpr auto make_smem_layout(Tiler smem) {
|
||||
constexpr int feature_size = get_feature_size<transpose>(smem);
|
||||
constexpr int swizzle_bits =
|
||||
get_swizzle_bits<feature_size, itemsize, copy_bits>();
|
||||
|
||||
using F = Int<feature_size>;
|
||||
using BaseLayout = std::conditional_t<
|
||||
transpose,
|
||||
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
|
||||
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
|
||||
|
||||
auto swizzled =
|
||||
make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
|
||||
|
||||
return tile_to_shape(swizzled, smem);
|
||||
}
|
||||
|
||||
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
|
||||
constexpr auto make_result_smem_layout(Tiler smem) {
|
||||
constexpr int feature_size = get_feature_size<transpose>(smem);
|
||||
constexpr int swizzle_bits =
|
||||
get_swizzle_bits<feature_size, itemsize, copy_bits>();
|
||||
|
||||
using F = Int<feature_size>;
|
||||
using BaseLayout = std::conditional_t<
|
||||
transpose,
|
||||
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
|
||||
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
|
||||
|
||||
auto swizzled = make_composed_layout(
|
||||
Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
|
||||
|
||||
return tile_to_shape(swizzled, smem);
|
||||
}
|
||||
|
||||
template <
|
||||
int num_threads,
|
||||
int itemsize,
|
||||
bool transpose,
|
||||
int copy_bits,
|
||||
typename Copier,
|
||||
typename Tiler>
|
||||
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
|
||||
constexpr int num_elements = copy_bits / itemsize;
|
||||
constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
|
||||
constexpr int copies_per_feature = feature_size / num_elements;
|
||||
|
||||
using E = Int<num_elements>;
|
||||
using C = Int<copies_per_feature>;
|
||||
using R = Int<num_threads / copies_per_feature>;
|
||||
|
||||
using ThreadLayout = std::conditional_t<
|
||||
transpose,
|
||||
Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
|
||||
Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
|
||||
using ValueLayout = std::conditional_t<
|
||||
transpose,
|
||||
Layout<cute::Shape<E, _1>>,
|
||||
Layout<cute::Shape<_1, E>>>;
|
||||
|
||||
return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
|
||||
}
|
||||
|
||||
template <int rasterization_factor>
|
||||
__device__ inline int2 raster_tile(int x, int y) {
|
||||
return {
|
||||
x / rasterization_factor,
|
||||
(x % rasterization_factor) + y * rasterization_factor};
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename SLayoutA,
|
||||
typename SLayoutB,
|
||||
typename SLayoutC,
|
||||
typename CopyA,
|
||||
typename CopyB,
|
||||
typename CopyC,
|
||||
typename MMA,
|
||||
int rasterization_factor>
|
||||
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
|
||||
const T* __restrict__ A,
|
||||
const T* __restrict__ B,
|
||||
T* __restrict__ C,
|
||||
SLayoutA SA,
|
||||
SLayoutB SB,
|
||||
SLayoutC SC,
|
||||
CopyA copy_a,
|
||||
CopyB copy_b,
|
||||
CopyC copy_c,
|
||||
MMA mma,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr auto BM = size<0>(SA);
|
||||
constexpr auto BN = size<0>(SB);
|
||||
constexpr auto BK = size<1>(SA);
|
||||
constexpr auto PIPE = size<2>(SA);
|
||||
|
||||
const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
|
||||
const int blocks_m = ceil_div(M, BM);
|
||||
const int blocks_n = ceil_div(N, BN);
|
||||
|
||||
// Exit early if the tile is OOB
|
||||
if (tile.x >= blocks_m || tile.y >= blocks_n) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Make the full tensors
|
||||
Tensor full_A =
|
||||
make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
|
||||
Tensor full_B =
|
||||
make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
|
||||
Tensor full_C =
|
||||
make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
|
||||
|
||||
// Partition the tensors into tiles and select the ones for this threadblock
|
||||
Tensor local_A =
|
||||
local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
|
||||
Tensor local_B =
|
||||
local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
|
||||
Tensor local_C =
|
||||
local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
|
||||
|
||||
// Make shared memory tensors
|
||||
extern __shared__ char shared_memory[];
|
||||
T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
|
||||
T* shared_B_ptr =
|
||||
reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
|
||||
T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
|
||||
Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
|
||||
Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
|
||||
Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
|
||||
|
||||
// Get the copies that correspond to this thread
|
||||
auto thread_copy_a = copy_a.get_slice(threadIdx.x);
|
||||
Tensor local_A_src = thread_copy_a.partition_S(local_A);
|
||||
Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
|
||||
auto thread_copy_b = copy_b.get_slice(threadIdx.x);
|
||||
Tensor local_B_src = thread_copy_a.partition_S(local_B);
|
||||
Tensor local_B_dst = thread_copy_a.partition_D(shared_B);
|
||||
auto thread_copy_c = copy_c.get_slice(threadIdx.x);
|
||||
Tensor local_C_src = thread_copy_c.partition_S(shared_C);
|
||||
Tensor local_C_dst = thread_copy_c.partition_D(local_C);
|
||||
|
||||
// Start fetches
|
||||
int k_tile_count = size<2>(local_A);
|
||||
int k_tile_next = 0;
|
||||
CUTE_UNROLL
|
||||
for (int k = 0; k < PIPE - 1; k++) {
|
||||
copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
|
||||
copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
|
||||
cp_async_fence();
|
||||
k_tile_count--;
|
||||
k_tile_next += (k_tile_count > 0);
|
||||
}
|
||||
|
||||
// Get the MMA that corresponds to this thread and allocate registers
|
||||
auto thread_mma = mma.get_slice(threadIdx.x);
|
||||
Tensor mma_shared_A = thread_mma.partition_A(shared_A);
|
||||
Tensor mma_shared_B = thread_mma.partition_B(shared_B);
|
||||
Tensor mma_shared_C = thread_mma.partition_C(shared_C);
|
||||
Tensor mma_global_C = thread_mma.partition_C(local_C);
|
||||
Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
|
||||
Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
|
||||
Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
|
||||
clear(mma_frag_C);
|
||||
|
||||
// Make shared to register copies
|
||||
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
|
||||
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
|
||||
auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
|
||||
auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
|
||||
auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
|
||||
auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
|
||||
Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
|
||||
Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
|
||||
Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
|
||||
Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
|
||||
|
||||
constexpr auto RPIPE = size<2>(mma_shared_A);
|
||||
int smem_read = 0;
|
||||
int smem_write = PIPE - 1;
|
||||
Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
|
||||
Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
|
||||
|
||||
// Start the register pipeline
|
||||
if constexpr (RPIPE > 1) {
|
||||
cp_async_wait<PIPE - 2>();
|
||||
__syncthreads();
|
||||
copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
|
||||
copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
|
||||
}
|
||||
|
||||
CUTE_NO_UNROLL
|
||||
while (k_tile_count > -(PIPE - 1)) {
|
||||
CUTE_UNROLL
|
||||
for (int k_block = 0; k_block < RPIPE; k_block++) {
|
||||
if (k_block == RPIPE - 1) {
|
||||
mma_A_src_p = mma_A_src(_, _, _, smem_read);
|
||||
mma_B_src_p = mma_B_src(_, _, _, smem_read);
|
||||
cp_async_wait<PIPE - 2>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the next register tile
|
||||
auto k_block_next = (k_block + 1) % RPIPE;
|
||||
copy(
|
||||
s2r_copy_a,
|
||||
mma_A_src_p(_, _, k_block_next),
|
||||
mma_A_dst(_, _, k_block_next));
|
||||
copy(
|
||||
s2r_copy_b,
|
||||
mma_B_src_p(_, _, k_block_next),
|
||||
mma_B_dst(_, _, k_block_next));
|
||||
|
||||
if (k_block == 0) {
|
||||
copy(
|
||||
copy_a,
|
||||
local_A_src(_, _, _, k_tile_next),
|
||||
local_A_dst(_, _, _, smem_write));
|
||||
copy(
|
||||
copy_b,
|
||||
local_B_src(_, _, _, k_tile_next),
|
||||
local_B_dst(_, _, _, smem_write));
|
||||
cp_async_fence();
|
||||
k_tile_count--;
|
||||
k_tile_next += (k_tile_count > 0);
|
||||
smem_write = smem_read;
|
||||
smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
|
||||
}
|
||||
|
||||
gemm(
|
||||
mma,
|
||||
mma_frag_A(_, _, k_block),
|
||||
mma_frag_B(_, _, k_block),
|
||||
mma_frag_C);
|
||||
}
|
||||
}
|
||||
|
||||
copy(mma_frag_C, mma_shared_C);
|
||||
__syncthreads();
|
||||
copy(copy_c, local_C_src, local_C_dst);
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
// print("fC: "); print(mma_frag_C); print("\n");
|
||||
// print("sC: "); print(mma_shared_C); print("\n");
|
||||
// print("dC: "); print(local_C_dst); print("\n");
|
||||
//
|
||||
// print(s2r_atom_a); print("\n");
|
||||
// }
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_gemm(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cu::CommandEncoder& enc) {
|
||||
enc.set_input_array(a);
|
||||
enc.set_input_array(b);
|
||||
enc.set_output_array(out);
|
||||
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
|
||||
using namespace cute;
|
||||
|
||||
// Tile definitions
|
||||
auto BM = Int<128>{};
|
||||
auto BN = Int<128>{};
|
||||
auto BK = Int<64>{};
|
||||
auto BP = Int<3>{};
|
||||
auto GM = Int<8>{};
|
||||
|
||||
// Thread definitions
|
||||
using TM = Int<2>;
|
||||
using TN = Int<2>;
|
||||
using TK = Int<1>;
|
||||
constexpr int num_threads = TM::value * TN::value * 32;
|
||||
|
||||
auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
|
||||
auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
|
||||
auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
|
||||
|
||||
constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
|
||||
|
||||
auto async_copy_op =
|
||||
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
|
||||
auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
|
||||
async_copy_op, make_shape(BM, BK));
|
||||
auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
|
||||
async_copy_op, make_shape(BN, BK));
|
||||
|
||||
auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
|
||||
auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
|
||||
sync_copy_op, make_shape(BM, BN));
|
||||
|
||||
auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
|
||||
auto tiled_mma = make_tiled_mma(
|
||||
mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
|
||||
|
||||
auto kernel = matmul_kernel<
|
||||
bf16,
|
||||
decltype(SA),
|
||||
decltype(SB),
|
||||
decltype(SC),
|
||||
decltype(tiled_copy_a),
|
||||
decltype(tiled_copy_b),
|
||||
decltype(tiled_copy_c),
|
||||
decltype(tiled_mma),
|
||||
GM.value>;
|
||||
configure_matmul(kernel, smem_size);
|
||||
|
||||
dim3 block(size(tiled_mma));
|
||||
dim3 grid(
|
||||
size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
|
||||
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
smem_size,
|
||||
a.data<bf16>(),
|
||||
b.data<bf16>(),
|
||||
out.data<bf16>(),
|
||||
SA,
|
||||
SB,
|
||||
SC,
|
||||
tiled_copy_a,
|
||||
tiled_copy_b,
|
||||
tiled_copy_c,
|
||||
tiled_mma,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
} else {
|
||||
throw std::runtime_error("Only bfloat16 supported");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
void cutlass_gemm(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cu::CommandEncoder& enc);
|
||||
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/steel/gemm.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Kernel>
|
||||
static void configure_smem(Kernel kernel, int SM) {
|
||||
static bool done = false;
|
||||
if (done) {
|
||||
return;
|
||||
}
|
||||
std::cout << "configuring" << std::endl;
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
|
||||
cudaFuncSetAttribute(
|
||||
kernel,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout,
|
||||
cudaSharedmemCarveoutMaxShared);
|
||||
done = true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void simple_gemm(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cu::CommandEncoder& enc) {
|
||||
enc.set_input_array(a);
|
||||
enc.set_input_array(b);
|
||||
enc.set_output_array(out);
|
||||
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr int BM = 128;
|
||||
constexpr int BN = 128;
|
||||
constexpr int BK = 32;
|
||||
constexpr int PIPE = 3;
|
||||
constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
|
||||
constexpr int WM = 2;
|
||||
constexpr int WN = 4;
|
||||
|
||||
auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
|
||||
configure_smem(kernel, SM);
|
||||
|
||||
dim3 grid(N / BN, M / BM);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
WM * WN * WARP_SIZE,
|
||||
SM,
|
||||
a.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
N,
|
||||
K);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
void simple_gemm(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cu::CommandEncoder& enc);
|
||||
|
||||
}
|
||||
@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
});
|
||||
|
||||
cu::KernelArgs args;
|
||||
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
});
|
||||
|
||||
cu::KernelArgs args;
|
||||
@@ -268,8 +268,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_tuple(
|
||||
false, jit_source_gather_axis, std::move(kernel_names));
|
||||
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
@@ -372,8 +371,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_tuple(
|
||||
false, jit_source_scatter_axis, std::move(kernel_names));
|
||||
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
|
||||
@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
|
||||
bool read_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
std::string& ptx,
|
||||
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
std::vector<char>* ptx,
|
||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return false;
|
||||
}
|
||||
@@ -117,15 +117,15 @@ bool read_cached_ptx(
|
||||
if (!ptx_file.good()) {
|
||||
return false;
|
||||
}
|
||||
ptx.resize(ptx_size);
|
||||
ptx_file.read(ptx.data(), ptx_size);
|
||||
ptx->resize(ptx_size);
|
||||
ptx_file.read(ptx->data(), ptx_size);
|
||||
|
||||
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::string line;
|
||||
while (std::getline(txt_file, line)) {
|
||||
auto tab = line.find('\t');
|
||||
if (tab != std::string::npos) {
|
||||
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@@ -135,7 +135,7 @@ bool read_cached_ptx(
|
||||
void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::string& ptx,
|
||||
const std::vector<char>& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
const std::string& source_code) {
|
||||
if (cache_dir.empty()) {
|
||||
@@ -217,85 +217,85 @@ constexpr const char* g_headers[] = {
|
||||
jit_source_utils,
|
||||
};
|
||||
|
||||
void compile(
|
||||
} // namespace
|
||||
|
||||
JitModule::JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& kernel_names,
|
||||
std::string& ptx,
|
||||
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
// Create the program
|
||||
nvrtcProgram prog;
|
||||
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||
&prog,
|
||||
source.c_str(),
|
||||
(module_name + ".cu").c_str(),
|
||||
std::size(g_headers),
|
||||
g_headers,
|
||||
g_include_names));
|
||||
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||
&prog,
|
||||
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||
for (const auto& name : kernel_names) {
|
||||
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||
const KernelBuilder& builder) {
|
||||
// Check cache.
|
||||
std::vector<char> ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
||||
// Create program.
|
||||
auto [source_code, kernel_names] = builder();
|
||||
nvrtcProgram prog;
|
||||
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||
&prog,
|
||||
source_code.c_str(),
|
||||
(module_name + ".cu").c_str(),
|
||||
std::size(g_headers),
|
||||
g_headers,
|
||||
g_include_names));
|
||||
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||
&prog,
|
||||
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||
for (const auto& name : kernel_names) {
|
||||
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
std::vector<const char*> args;
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include = cccl_dir();
|
||||
if (!cccl_include.empty()) {
|
||||
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||
args.push_back(cccl_include.c_str());
|
||||
}
|
||||
std::string cuda_include =
|
||||
fmt::format("--include-path={}/include", cuda_home());
|
||||
args.push_back(cuda_include.c_str());
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, args.size(), args.data());
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
std::vector<char> log(log_size + 1, 0);
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||
}
|
||||
|
||||
// Get mangled names of kernel names.
|
||||
for (const auto& name : kernel_names) {
|
||||
const char* mangled;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||
ptx_kernels.emplace_back(name, mangled);
|
||||
}
|
||||
|
||||
// Get ptx data.
|
||||
size_t ptx_size;
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||
}
|
||||
ptx.resize(ptx_size, 0);
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
std::vector<const char*> args;
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include = cccl_dir();
|
||||
if (!cccl_include.empty()) {
|
||||
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||
args.push_back(cccl_include.c_str());
|
||||
}
|
||||
std::string cuda_include =
|
||||
fmt::format("--include-path={}/include", cuda_home());
|
||||
args.push_back(cuda_include.c_str());
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, args.size(), args.data());
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
std::vector<char> log(log_size + 1, 0);
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||
}
|
||||
|
||||
// Get mangled names of kernel names.
|
||||
for (const auto& name : kernel_names) {
|
||||
const char* mangled;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||
ptx_kernels.emplace_back(name, mangled);
|
||||
}
|
||||
|
||||
// Get ptx data.
|
||||
size_t ptx_size;
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||
}
|
||||
ptx.resize(ptx_size);
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
}
|
||||
|
||||
void load_module(
|
||||
const std::string& module_name,
|
||||
const std::string& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
CUmodule& module_,
|
||||
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
|
||||
// Load module.
|
||||
char jit_log[4089] = {};
|
||||
CUjit_option options[] = {
|
||||
@@ -312,69 +312,21 @@ void load_module(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels[name] = std::make_pair(kernel, false);
|
||||
kernels_[name] = kernel;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JitModule::JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder,
|
||||
bool use_disk_cache) {
|
||||
// Will hold the actual device executable source code and kernel names
|
||||
std::string ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
|
||||
// Try to load them from the file cache
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
|
||||
auto [precompiled, source_code, kernel_names] = builder();
|
||||
|
||||
// Get the PTX or cubin
|
||||
if (precompiled) {
|
||||
ptx = std::move(source_code);
|
||||
for (auto& name : kernel_names) {
|
||||
ptx_kernels.emplace_back(name, name);
|
||||
}
|
||||
} else {
|
||||
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
|
||||
}
|
||||
|
||||
// If requested save them in the file cache for the next launch
|
||||
if (use_disk_cache) {
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the module
|
||||
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
|
||||
}
|
||||
|
||||
JitModule::~JitModule() {
|
||||
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel) {
|
||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
auto it = kernels_.find(kernel_name);
|
||||
if (it == kernels_.end()) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("There is no kernel named {}.", kernel_name));
|
||||
}
|
||||
|
||||
// If it is the first time we run this kernel then configure it. Do it only
|
||||
// once!
|
||||
if (!it->second.second) {
|
||||
if (configure_kernel) {
|
||||
configure_kernel(it->second.first);
|
||||
}
|
||||
it->second.second = true;
|
||||
}
|
||||
|
||||
return it->second.first;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||
@@ -385,12 +337,11 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder,
|
||||
bool cache) {
|
||||
const KernelBuilder& builder) {
|
||||
auto& map = get_jit_module_cache();
|
||||
auto it = map.find(name);
|
||||
if (it == map.end()) {
|
||||
it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
|
||||
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
@@ -19,8 +19,7 @@ namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
|
||||
using KernelBuilderResult = std::tuple<
|
||||
/* precompiled */ bool,
|
||||
using KernelBuilderResult = std::pair<
|
||||
/* source code */ std::string,
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
|
||||
@@ -64,16 +63,14 @@ struct KernelArgs {
|
||||
private:
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuGraphAddKernelNode API requires passing pointers to arguments so
|
||||
// store temporary values until the node is created.
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
// temporary values untill kernel is launched.
|
||||
using Arg = std::variant<
|
||||
std::monostate,
|
||||
CUdeviceptr,
|
||||
bool,
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
float,
|
||||
SmallVector<const void*>,
|
||||
SmallVector<int32_t>,
|
||||
SmallVector<int64_t>>;
|
||||
@@ -85,19 +82,16 @@ class JitModule {
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder,
|
||||
bool cache);
|
||||
const KernelBuilder& builder);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
CUfunction get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
@@ -105,7 +99,6 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder,
|
||||
bool use_disk_cache = true);
|
||||
const KernelBuilder& builder);
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
|
||||
#include "mlx/backend/cuda/gemms/gemv.h"
|
||||
#include "mlx/backend/cuda/gemms/simple_gemm.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -13,14 +11,8 @@
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
int get_test_gemm() {
|
||||
static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
|
||||
return t;
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, array>
|
||||
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
@@ -103,21 +95,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
|
||||
b_transposed && batch_count == 1 && get_test_gemm() == 1) {
|
||||
cu::simple_gemm(a, b, out, M, N, K, encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
|
||||
b_transposed && batch_count == 1 && get_test_gemm() == 2) {
|
||||
cu::cutlass_gemm(a, b, out, M, N, K, encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
CublasGemm gemm(
|
||||
cu::Matmul matmul(
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
@@ -131,7 +111,14 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
|
||||
if ((batch_count / batch_shape.back()) == 1) {
|
||||
matmul.run(encoder, out, a, b);
|
||||
return;
|
||||
}
|
||||
|
||||
matmul.run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -199,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
|
||||
CublasGemm gemm(
|
||||
cu::Matmul matmul(
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
@@ -215,7 +202,12 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back(),
|
||||
c_batch_strides.back());
|
||||
gemm.run(
|
||||
|
||||
if ((batch_count / batch_shape.back()) == 1) {
|
||||
matmul.run(encoder, out, a, b, c, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
matmul.run_batched(
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
|
||||
@@ -1,47 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/fast.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
CustomKernelFunction cuda_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool,
|
||||
int) {
|
||||
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
|
||||
}
|
||||
|
||||
std::vector<array> precompiled_cuda_kernel(
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<Shape>&,
|
||||
const std::vector<Dtype>&,
|
||||
const std::vector<ScalarArg>&,
|
||||
std::tuple<int, int, int>,
|
||||
std::tuple<int, int, int>,
|
||||
int shared_memory,
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
StreamOrDevice) {
|
||||
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -41,6 +41,10 @@ NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
|
||||
@@ -8,13 +8,19 @@
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
// cudnn_frontend.h redefines this macro.
|
||||
#undef CHECK_CUDA_ERROR
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
|
||||
namespace fe = cudnn_frontend;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
@@ -639,6 +645,294 @@ void sdpa_vector_fallback(
|
||||
}
|
||||
}
|
||||
|
||||
struct SDPACacheKey {
|
||||
int device_id;
|
||||
fe::DataType_t cudnn_type;
|
||||
|
||||
int B;
|
||||
int H;
|
||||
int D;
|
||||
|
||||
int qL;
|
||||
int kL;
|
||||
|
||||
int gqa_factor;
|
||||
float scale;
|
||||
|
||||
int64_t Q_strides[3];
|
||||
int64_t K_strides[3];
|
||||
int64_t V_strides[3];
|
||||
int64_t O_strides[3];
|
||||
|
||||
bool generate_stats;
|
||||
bool causal_mask;
|
||||
};
|
||||
|
||||
auto& sdpa_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, std::shared_ptr<fe::graph::Graph>>
|
||||
cache(
|
||||
/* capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
#define Q_UID 1
|
||||
#define K_UID 2
|
||||
#define V_UID 3
|
||||
#define O_UID 4
|
||||
#define STATS_UID 5
|
||||
|
||||
std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
const SDPACacheKey& cache_key) {
|
||||
// Check if graph has already been fully built
|
||||
if (auto it = sdpa_cache().find(cache_key); it != sdpa_cache().end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Set up new graph
|
||||
auto graph = std::make_shared<fe::graph::Graph>();
|
||||
|
||||
graph->set_io_data_type(cache_key.cudnn_type)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto Q = graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("Q")
|
||||
.set_uid(Q_UID)
|
||||
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
|
||||
.set_stride(
|
||||
{cache_key.Q_strides[0],
|
||||
cache_key.Q_strides[1],
|
||||
cache_key.Q_strides[2],
|
||||
1}));
|
||||
|
||||
int h_kv = cache_key.H / cache_key.gqa_factor;
|
||||
auto K =
|
||||
graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("K")
|
||||
.set_uid(K_UID)
|
||||
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
|
||||
.set_stride(
|
||||
{cache_key.K_strides[0],
|
||||
cache_key.K_strides[1],
|
||||
cache_key.V_strides[2],
|
||||
1}));
|
||||
|
||||
auto V =
|
||||
graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("V")
|
||||
.set_uid(V_UID)
|
||||
.set_dim({cache_key.B, h_kv, cache_key.kL, cache_key.D})
|
||||
.set_stride(
|
||||
{cache_key.V_strides[0],
|
||||
cache_key.V_strides[1],
|
||||
cache_key.V_strides[2],
|
||||
1}));
|
||||
|
||||
auto sdpa_options = fe::graph::SDPA_attributes()
|
||||
.set_name("flash_attention")
|
||||
.set_is_inference(!cache_key.generate_stats)
|
||||
.set_attn_scale(cache_key.scale);
|
||||
|
||||
if (cache_key.causal_mask && cache_key.qL > 1) {
|
||||
sdpa_options.set_diagonal_alignment(fe::DiagonalAlignment_t::TOP_LEFT)
|
||||
.set_diagonal_band_right_bound(0);
|
||||
}
|
||||
|
||||
auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);
|
||||
|
||||
O->set_output(true)
|
||||
.set_uid(O_UID)
|
||||
.set_dim({cache_key.B, cache_key.H, cache_key.qL, cache_key.D})
|
||||
.set_stride(
|
||||
{cache_key.O_strides[0],
|
||||
cache_key.O_strides[1],
|
||||
cache_key.O_strides[2],
|
||||
1});
|
||||
|
||||
if (cache_key.generate_stats) {
|
||||
Stats->set_output(true)
|
||||
.set_data_type(fe::DataType_t::FLOAT)
|
||||
.set_uid(STATS_UID);
|
||||
}
|
||||
|
||||
// Build and Validate cudnn graph
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
|
||||
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||
if (cudnnGetVersion() < 90600) {
|
||||
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
||||
if (!build_status.is_good()) {
|
||||
throw std::runtime_error(
|
||||
"Unable to build cudnn graph for attention."
|
||||
" Failed with message: " +
|
||||
build_status.get_message());
|
||||
}
|
||||
|
||||
} else {
|
||||
auto val_status = graph->validate();
|
||||
auto op_status = graph->build_operation_graph(handle);
|
||||
|
||||
auto plan_stauts =
|
||||
graph->create_execution_plans({cudnn_frontend::HeurMode_t::A});
|
||||
if (!plan_stauts.is_good()) {
|
||||
throw std::runtime_error(
|
||||
"Unable to create exec plan for cudnn attention."
|
||||
" Failed with message: " +
|
||||
plan_stauts.get_message());
|
||||
}
|
||||
|
||||
graph->select_behavior_notes(
|
||||
{cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
|
||||
|
||||
auto support_status = graph->check_support(handle);
|
||||
if (!support_status.is_good()) {
|
||||
throw std::runtime_error(
|
||||
"No cuda graph support for cudnn attention."
|
||||
" Failed with message: " +
|
||||
support_status.get_message());
|
||||
}
|
||||
|
||||
auto build_status = graph->build_plans(handle);
|
||||
if (!build_status.is_good()) {
|
||||
throw std::runtime_error(
|
||||
"Unable to build cudnn graph for attention."
|
||||
" Failed with message: " +
|
||||
build_status.get_message());
|
||||
}
|
||||
}
|
||||
|
||||
auto [it, _] = sdpa_cache().emplace(cache_key, graph);
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return fe::DataType_t::INT8;
|
||||
case int32:
|
||||
return fe::DataType_t::INT32;
|
||||
case uint8:
|
||||
return fe::DataType_t::UINT8;
|
||||
case float16:
|
||||
return fe::DataType_t::HALF;
|
||||
case bfloat16:
|
||||
return fe::DataType_t::BFLOAT16;
|
||||
case float32:
|
||||
return fe::DataType_t::FLOAT;
|
||||
case float64:
|
||||
return fe::DataType_t::DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
void sdpa_cudnn(
|
||||
const Stream& s,
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
encoder.set_output_array(o);
|
||||
|
||||
auto cudnn_type = dtype_to_cudnn_type(q.dtype());
|
||||
|
||||
int B = q.shape(0);
|
||||
int H = q.shape(1);
|
||||
int D = q.shape(3);
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
|
||||
int qL = q.shape(2);
|
||||
int kL = k.shape(2);
|
||||
|
||||
SDPACacheKey cache_key{
|
||||
/* int device_id = */ encoder.device().cuda_device(),
|
||||
/* fe::DataType_t cudnn_type = */ cudnn_type,
|
||||
|
||||
/* int B = */ B,
|
||||
/* int H = */ H,
|
||||
/* int D = */ D,
|
||||
|
||||
/* int qL = */ qL,
|
||||
/* int kL = */ kL,
|
||||
|
||||
/* int gqa_factor = */ gqa_factor,
|
||||
/* float scale = */ scale,
|
||||
|
||||
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)},
|
||||
|
||||
/* bool generate_stats = */ false,
|
||||
/* bool causal_mask = */ do_causal_};
|
||||
|
||||
auto graph = get_sdpa_forward_graph(encoder, cache_key);
|
||||
|
||||
int64_t workspace_size = 0;
|
||||
auto workspace_status = graph->get_workspace_size(workspace_size);
|
||||
if (!workspace_status.is_good()) {
|
||||
throw std::runtime_error("Unable to get workspace for cudnn attention.");
|
||||
}
|
||||
|
||||
array workspace(
|
||||
allocator::malloc(workspace_size), {int(workspace_size)}, uint8);
|
||||
auto workspace_ptr = workspace.data<void>();
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack = {
|
||||
{Q_UID, const_cast<void*>(q.data<void>())},
|
||||
{K_UID, const_cast<void*>(k.data<void>())},
|
||||
{V_UID, const_cast<void*>(v.data<void>())},
|
||||
{O_UID, o.data<void>()}};
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||
if (cudnnGetVersion() < 90600) {
|
||||
auto capture = encoder.capture_context();
|
||||
auto exec_status = graph->execute(handle, variant_pack, workspace_ptr);
|
||||
|
||||
if (!exec_status.is_good()) {
|
||||
capture.discard = true;
|
||||
throw std::runtime_error(
|
||||
"Unable to execute cudnn attention."
|
||||
" Failed with message: " +
|
||||
exec_status.get_message());
|
||||
}
|
||||
} else {
|
||||
cudaGraph_t cu_graph;
|
||||
cudaGraphCreate(&cu_graph, 0);
|
||||
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&cu_graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
|
||||
auto cu_graph_status = graph->populate_cuda_graph(
|
||||
handle, variant_pack, workspace_ptr, cu_graph);
|
||||
|
||||
if (!cu_graph_status.is_good()) {
|
||||
throw std::runtime_error(
|
||||
"Unable to add cuda graph for cudnn attention."
|
||||
" Failed with message: " +
|
||||
cu_graph_status.get_message());
|
||||
}
|
||||
|
||||
encoder.add_graph_node(cu_graph);
|
||||
}
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace fast {
|
||||
@@ -651,9 +945,6 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
if (detail::in_grad_tracing()) {
|
||||
return true;
|
||||
}
|
||||
if (s.device == Device::cpu) {
|
||||
return true;
|
||||
}
|
||||
@@ -669,7 +960,15 @@ bool ScaledDotProductAttention::use_fallback(
|
||||
const bool supported_vector_config =
|
||||
sdpa_supported_head_dim && query_sequence_length < 4;
|
||||
|
||||
const bool supported_config = supported_vector_config;
|
||||
auto& cu_device = cu::device(s.device);
|
||||
|
||||
const bool supported_matrix_config = query_sequence_length > 4 &&
|
||||
cu_device.compute_capability_major() >= 8 &&
|
||||
query_sequence_length == key_sequence_length &&
|
||||
(q.dtype() == float16 || q.dtype() == bfloat16);
|
||||
|
||||
const bool supported_config =
|
||||
(supported_matrix_config || supported_vector_config);
|
||||
|
||||
return has_arr_mask || !supported_config;
|
||||
}
|
||||
@@ -703,6 +1002,10 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
return arr.strides(-1) == 1;
|
||||
};
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) < 4) {
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
@@ -756,7 +1059,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
/* bool row_contiguous = */ o.shape(2) == 1,
|
||||
/* bool row_contiguous = */ 0,
|
||||
/* bool col_contiguous = */ 0,
|
||||
};
|
||||
|
||||
@@ -770,9 +1073,35 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
}
|
||||
|
||||
// Full attention mode should never reach here
|
||||
// Full attention mode
|
||||
else {
|
||||
throw std::runtime_error("Doesn't support matrix yet.");
|
||||
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
|
||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
|
||||
for (const auto& cp : copies) {
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
int64_t str_oD = 1;
|
||||
int64_t str_oH = o.shape(3);
|
||||
int64_t str_oL = o.shape(1) * str_oH;
|
||||
int64_t str_oB = o.shape(2) * str_oL;
|
||||
size_t data_size = o.shape(0) * str_oB;
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
/* bool row_contiguous = */ 0,
|
||||
/* bool col_contiguous = */ 0,
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
|
||||
return sdpa_cudnn(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
@@ -4,189 +4,95 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T, int BM, int BN, int BK, int WM, int WN>
|
||||
__device__ inline void gemm_ab_t(
|
||||
RegisterTile<float, BM / WM, BN / WN>& C,
|
||||
SharedTile<T, BM, BK>& As,
|
||||
SharedTile<T, BN, BK>& Bs,
|
||||
RegisterTileLoader<SharedTile<T, BM, BK>>& rloader_a,
|
||||
RegisterTileLoader<SharedTile<T, BN, BK>>& rloader_b) {
|
||||
RegisterTile<T, BM / WM, 16> A[2];
|
||||
RegisterTile<T, BN / WN, 16> B[2];
|
||||
|
||||
rloader_a.load(A[0], As.base_addr(), 0);
|
||||
rloader_b.load(B[0], Bs.base_addr(), 0);
|
||||
|
||||
MLX_UNROLL
|
||||
for (int k = 1; k < BK / 16; k++) {
|
||||
rloader_a.load(A[k & 1], As.base_addr(), k);
|
||||
rloader_b.load(B[k & 1], Bs.base_addr(), k);
|
||||
|
||||
mma_t(C, A[(k - 1) & 1], B[(k - 1) & 1]);
|
||||
}
|
||||
mma_t(C, A[(BK / 16 - 1) & 1], B[(BK / 16 - 1) & 1]);
|
||||
}
|
||||
|
||||
/**
|
||||
* An example gemm written with the utils.
|
||||
*
|
||||
* Computes A @ B.T when A and B are all aligned with the block sizes.
|
||||
*/
|
||||
// template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
|
||||
//__global__ __launch_bounds__(WM * WN * WARP_SIZE, 1)
|
||||
// void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
// constexpr int NUM_WARPS = WM * WN;
|
||||
// constexpr int WARP_STEP_M = BM / WM;
|
||||
// constexpr int WARP_STEP_N = BN / WN;
|
||||
//
|
||||
// // Precompute some offsets for each thread
|
||||
// const int warpid = threadIdx.x / 32;
|
||||
// const int laneid = threadIdx.x % 32;
|
||||
// const int wm = warpid / WN;
|
||||
// const int wn = warpid % WN;
|
||||
// const int offset_m = wm * WARP_STEP_M;
|
||||
// const int offset_n = wn * WARP_STEP_N;
|
||||
//
|
||||
// // Allocate shared memory
|
||||
// extern __shared__ char shmem[];
|
||||
// SharedTile<T, BM, BK>(&as)[PIPE] =
|
||||
// *(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
|
||||
// SharedTile<T, BN, BK>(&bs)[PIPE] =
|
||||
// *(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
|
||||
//
|
||||
// // Move the global pointers to the tile
|
||||
// a += blockIdx.y * BM * K;
|
||||
// b += blockIdx.x * BN * K;
|
||||
// y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||
//
|
||||
// // Make the loaders to/from SMEM
|
||||
// SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>> sloader_a(a, K);
|
||||
// SharedTileLoader<NUM_WARPS, SharedTile<T, BN, BK>> sloader_b(b, K);
|
||||
// RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
|
||||
// RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
|
||||
//
|
||||
// // Start the SM pipeline
|
||||
// MLX_UNROLL
|
||||
// for (int i = 0; i < PIPE - 1; i++) {
|
||||
// sloader_a.load_async(as[i].base_addr());
|
||||
// sloader_b.load_async(bs[i].base_addr());
|
||||
// cp_async_commit();
|
||||
// sloader_a.next();
|
||||
// sloader_b.next();
|
||||
// }
|
||||
//
|
||||
// // Allocate and zero the MMA accumulator
|
||||
// RegisterTile<float, BM / WM, BN / WN> C;
|
||||
// C.fill(0);
|
||||
//
|
||||
// // Matmul loop
|
||||
// int num_blocks = K / BK;
|
||||
// int sread = 0;
|
||||
// int swrite = PIPE - 1;
|
||||
// for (int i = 0; i < num_blocks; i++) {
|
||||
// cp_async_wait<PIPE - 1>();
|
||||
//
|
||||
// gemm_ab_t<T, BM, BN, BK, WM, WN>(
|
||||
// C, as[sread], bs[sread], rloader_a, rloader_b);
|
||||
//
|
||||
// sloader_a.load_async(as[swrite].base_addr());
|
||||
// sloader_b.load_async(bs[swrite].base_addr());
|
||||
// cp_async_commit();
|
||||
// sloader_a.next(i + PIPE < num_blocks);
|
||||
// sloader_b.next(i + PIPE < num_blocks);
|
||||
//
|
||||
// swrite = sread;
|
||||
// sread = (sread + 1) % PIPE;
|
||||
// }
|
||||
//
|
||||
// C.store_global(y, N, offset_m, offset_n);
|
||||
// }
|
||||
|
||||
template <typename T, int BM, int BN, int BK, int WM, int WN, int PIPE>
|
||||
__global__ __launch_bounds__(
|
||||
WM* WN* WARP_SIZE,
|
||||
1) void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
constexpr int NUM_WARPS = WM * WN;
|
||||
constexpr int WARP_STEP_M = BM / WM;
|
||||
constexpr int WARP_STEP_N = BN / WN;
|
||||
template <typename T, int BM, int BN, int BK>
|
||||
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||
constexpr int WARPS_M = 2;
|
||||
constexpr int WARPS_N = 2;
|
||||
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||
constexpr int WARP_STEP_N = BN / WARPS_N;
|
||||
|
||||
// Precompute some offsets for each thread
|
||||
const int warpid = threadIdx.x / 32;
|
||||
const int laneid = threadIdx.x % 32;
|
||||
const int wm = warpid / WN;
|
||||
const int wn = warpid % WN;
|
||||
const int wm = warpid / WARPS_N;
|
||||
const int wn = warpid % WARPS_N;
|
||||
const int offset_m = wm * WARP_STEP_M;
|
||||
const int offset_n = wn * WARP_STEP_N;
|
||||
|
||||
// Allocate shared memory
|
||||
extern __shared__ char shmem[];
|
||||
SharedTile<T, BM, BK>(&as)[PIPE] =
|
||||
*(SharedTile<T, BM, BK>(*)[PIPE])(&shmem[0]);
|
||||
SharedTile<T, BN, BK>(&bs)[PIPE] =
|
||||
*(SharedTile<T, BN, BK>(*)[PIPE])(&shmem[sizeof(T) * PIPE * BM * BK]);
|
||||
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
|
||||
SharedTile<T, BN, BK>(&bs)[2] =
|
||||
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
|
||||
|
||||
// Allocate registers for the MMA
|
||||
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||
RegisterTile<T, BM / WARPS_M, 16> A;
|
||||
RegisterTile<T, BN / WARPS_N, 16> B;
|
||||
|
||||
// Move the global pointers to the tile
|
||||
a += blockIdx.y * BM * K;
|
||||
b += blockIdx.x * BN * K;
|
||||
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||
|
||||
// Make the loaders to/from SMEM
|
||||
using sloader = SharedTileLoader<NUM_WARPS, SharedTile<T, BM, BK>>;
|
||||
constexpr int SSTEP = sloader::STEP_ROWS * sizeof(T) * BK;
|
||||
const int srow = threadIdx.x / sloader::NUM_LOADS_PER_ROW;
|
||||
const int scol =
|
||||
(threadIdx.x % sloader::NUM_LOADS_PER_ROW) * sloader::ELEMENTS_PER_LOAD;
|
||||
a += srow * K + scol;
|
||||
b += srow * K + scol;
|
||||
uint32_t sm_offsets[PIPE][2];
|
||||
MLX_UNROLL
|
||||
for (int s = 0; s < PIPE; s++) {
|
||||
sm_offsets[s][0] = as[s].loc(as[s].base_addr(), srow, scol);
|
||||
sm_offsets[s][1] = bs[s].loc(bs[s].base_addr(), srow, scol);
|
||||
}
|
||||
RegisterTileLoader<SharedTile<T, BM, BK>> rloader_a(offset_m, laneid);
|
||||
RegisterTileLoader<SharedTile<T, BN, BK>> rloader_b(offset_n, laneid);
|
||||
|
||||
// Start the SM pipeline
|
||||
MLX_UNROLL
|
||||
for (int s = 0; s < PIPE - 1; s++) {
|
||||
MLX_UNROLL
|
||||
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
|
||||
cp_async<16>(sm_offsets[s][0] + l * SSTEP, a);
|
||||
cp_async<16>(sm_offsets[s][1] + l * SSTEP, b);
|
||||
a += sloader::STEP_ROWS * K;
|
||||
b += sloader::STEP_ROWS * K;
|
||||
}
|
||||
cp_async_commit();
|
||||
}
|
||||
|
||||
// Allocate and zero the MMA accumulator
|
||||
RegisterTile<float, BM / WM, BN / WN> C;
|
||||
// Zero the accumulators
|
||||
C.fill(0);
|
||||
|
||||
// Matmul loop
|
||||
int num_blocks = K / BK;
|
||||
int sread = 0;
|
||||
int swrite = PIPE - 1;
|
||||
for (int i = 0; i < num_blocks; i++) {
|
||||
cp_async_wait<PIPE - 1>();
|
||||
// Start the SM pipeline
|
||||
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
|
||||
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
|
||||
cp_async_commit();
|
||||
|
||||
gemm_ab_t<T, BM, BN, BK, WM, WN>(
|
||||
C, as[sread], bs[sread], rloader_a, rloader_b);
|
||||
|
||||
if (false) {
|
||||
MLX_UNROLL
|
||||
for (int l = 0; l < sloader::NUM_LOADS_PER_THREAD; l++) {
|
||||
cp_async<16>(sm_offsets[swrite][0] + l * SSTEP, a);
|
||||
cp_async<16>(sm_offsets[swrite][1] + l * SSTEP, b);
|
||||
a += sloader::STEP_ROWS * K;
|
||||
b += sloader::STEP_ROWS * K;
|
||||
}
|
||||
}
|
||||
int tic = 0;
|
||||
for (int k_block = BK; k_block < K; k_block += BK) {
|
||||
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
|
||||
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
||||
cp_async_commit();
|
||||
cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
|
||||
swrite = sread;
|
||||
sread = (sread + 1) % PIPE;
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < BK / 16; k++) {
|
||||
A.load(
|
||||
as[tic],
|
||||
as[tic].base_addr(),
|
||||
offset_m + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
B.load(
|
||||
bs[tic],
|
||||
bs[tic].base_addr(),
|
||||
offset_n + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
|
||||
mma_t(C, A, B);
|
||||
}
|
||||
|
||||
tic ^= 1;
|
||||
}
|
||||
|
||||
// Empty the pipeline
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < BK / 16; k++) {
|
||||
A.load(
|
||||
as[tic],
|
||||
as[tic].base_addr(),
|
||||
offset_m + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
B.load(
|
||||
bs[tic],
|
||||
bs[tic].base_addr(),
|
||||
offset_n + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
|
||||
mma_t(C, A, B);
|
||||
}
|
||||
|
||||
C.store_global(y, N, offset_m, offset_n);
|
||||
|
||||
@@ -223,10 +223,59 @@ struct RegisterTile {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* A simple container of multiple Tile16x16.
|
||||
*
|
||||
* Provides utility functions for loading and manipulating collections of basic
|
||||
* tiles.
|
||||
*/
|
||||
template <typename T, int ROWS_, int COLS_>
|
||||
struct RegisterTile {
|
||||
static constexpr int ROWS = ROWS_;
|
||||
static constexpr int COLS = COLS_;
|
||||
static constexpr int TILES_X = COLS / 16;
|
||||
static constexpr int TILES_Y = ROWS / 16;
|
||||
|
||||
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||
|
||||
__device__ inline void fill(T v) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].fill(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tile>
|
||||
__device__ inline void
|
||||
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].load(
|
||||
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].store_global(
|
||||
x + (row + i * 16) * N + col + j * 16, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ROWS_, int COLS_>
|
||||
struct SharedTile {
|
||||
using value_type = T;
|
||||
|
||||
static constexpr int ROWS = ROWS_;
|
||||
static constexpr int COLS = COLS_;
|
||||
static constexpr int TILES_X = COLS / 16;
|
||||
@@ -268,26 +317,23 @@ struct SharedTile {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static inline uint32_t offset(int row, int col) {
|
||||
// Return the location of the element at (row, col) using the swizzle.
|
||||
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
||||
if constexpr (swizzle_bytes > 0) {
|
||||
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||
const int outer_idx = col / subtile_cols;
|
||||
const uint32_t addr = sizeof(T) *
|
||||
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||
col % subtile_cols);
|
||||
const uint32_t addr = ptr +
|
||||
sizeof(T) *
|
||||
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||
col % subtile_cols);
|
||||
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||
return (addr ^ swizzle);
|
||||
} else {
|
||||
return sizeof(T) * (row * COLS + col);
|
||||
return ptr + sizeof(T) * (row * COLS + col);
|
||||
}
|
||||
}
|
||||
|
||||
// Return the location of the element at (row, col) using the swizzle.
|
||||
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
||||
return ptr + offset(row, col);
|
||||
}
|
||||
|
||||
// Convenience functions to edit elements going through the swizzle.
|
||||
__device__ inline T& operator()(int row, int col) {
|
||||
return *ptr(data, row, col);
|
||||
@@ -318,76 +364,6 @@ struct SharedTile {
|
||||
}
|
||||
};
|
||||
|
||||
template <int NUM_WARPS, typename Tile>
|
||||
struct SharedTileLoader {
|
||||
using T = typename Tile::value_type;
|
||||
|
||||
static constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||
static constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||
static constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||
static constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||
static constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||
static constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||
|
||||
const T* x_;
|
||||
int N_;
|
||||
uint32_t offset_;
|
||||
|
||||
__device__ SharedTileLoader(const T* x, int N) : x_(x), N_(N) {
|
||||
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||
|
||||
x_ += row * N + col * ELEMENTS_PER_LOAD;
|
||||
offset_ = Tile::offset(row, col * ELEMENTS_PER_LOAD);
|
||||
}
|
||||
|
||||
__device__ inline void load_async(uint32_t base_address) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||
cp_async<16>(
|
||||
base_address + offset_ + i * STEP_ROWS * sizeof(T) * Tile::COLS,
|
||||
x_ + i * STEP_ROWS * N_);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void next() {
|
||||
x_ += Tile::COLS;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tile>
|
||||
struct RegisterTileLoader {
|
||||
using T = typename Tile::value_type;
|
||||
|
||||
uint32_t offset_[Tile::COLS / 16];
|
||||
|
||||
__device__ RegisterTileLoader(int offset_row, int laneid) {
|
||||
const int row = offset_row + laneid & 15;
|
||||
const int col = (laneid >> 4) << 3;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < Tile::COLS / 16; i++) {
|
||||
offset_[i] = Tile::offset(row, col + i * 16);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ROWS, int COLS>
|
||||
__device__ inline void
|
||||
load(RegisterTile<T, ROWS, COLS>& x, uint32_t base_address, int col) {
|
||||
constexpr int TILES_Y = RegisterTile<T, ROWS, COLS>::TILES_Y;
|
||||
constexpr int TILES_X = RegisterTile<T, ROWS, COLS>::TILES_X;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
x.data[i * TILES_X + j].load(
|
||||
base_address + offset_[j + col] + i * 16 * Tile::COLS * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Load the tile from global memory by loading 16 bytes at a time and storing
|
||||
* them immediately.
|
||||
|
||||
@@ -21,15 +21,15 @@ __device__ inline void cp_async(uint32_t row_address, const T* x) {
|
||||
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||
if constexpr (N == 16) {
|
||||
asm volatile(
|
||||
"cp.async.cg.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||
"l"(reinterpret_cast<const int4*>(x)));
|
||||
} else if constexpr (N == 8) {
|
||||
asm volatile(
|
||||
"cp.async.cg.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
|
||||
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
|
||||
"l"(reinterpret_cast<const int2*>(x)));
|
||||
} else if constexpr (N == 4) {
|
||||
asm volatile(
|
||||
"cp.async.cg.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
|
||||
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
|
||||
"l"(reinterpret_cast<const int*>(x)));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -39,98 +39,52 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename T, typename IdxT, int NDIM, int N_READS>
|
||||
template <typename Op, typename T, typename IdxT, int NDIM>
|
||||
__global__ void ternary_g_nd(
|
||||
const bool* a,
|
||||
const T* b,
|
||||
const T* c,
|
||||
T* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[NDIM - 1];
|
||||
auto a_stride_x = a_strides[NDIM - 1];
|
||||
auto b_stride_x = b_strides[NDIM - 1];
|
||||
auto c_stride_x = c_strides[NDIM - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data());
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));
|
||||
auto c_vec =
|
||||
load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));
|
||||
|
||||
AlignedVector<T, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename T, typename IdxT, int N_READS>
|
||||
template <typename Op, typename T, typename IdxT>
|
||||
__global__ void ternary_g(
|
||||
const bool* a,
|
||||
const T* b,
|
||||
const T* c,
|
||||
T* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
const __grid_constant__ Strides c_strides,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc(
|
||||
index,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data(),
|
||||
ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto a_stride_x = a_strides[ndim - 1];
|
||||
auto b_stride_x = b_strides[ndim - 1];
|
||||
auto c_stride_x = c_strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto [a_idx, b_idx, c_idx] = elem_to_loc(
|
||||
index_rest * shape_x,
|
||||
shape.data(),
|
||||
a_strides.data(),
|
||||
b_strides.data(),
|
||||
c_strides.data(),
|
||||
ndim);
|
||||
auto a_vec =
|
||||
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, false);
|
||||
auto b_vec =
|
||||
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, T(0));
|
||||
auto c_vec =
|
||||
load_vector<N_READS>(c + c_idx, index_x, shape_x, c_stride_x, T(0));
|
||||
|
||||
AlignedVector<T, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
@@ -169,55 +123,36 @@ void ternary_op_gpu_inplace(
|
||||
auto& b_strides = strides[1];
|
||||
auto& c_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto kernel =
|
||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel =
|
||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant(), 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
rest,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides),
|
||||
const_param<dims_constant()>(c_strides));
|
||||
});
|
||||
} else {
|
||||
auto kernel = cu::ternary_g<Op, DType, IdxT, 1>;
|
||||
if (work_per_thread == 4) {
|
||||
kernel = cu::ternary_g<Op, DType, IdxT, 4>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::ternary_g<Op, DType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
rest,
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
|
||||
@@ -37,36 +37,19 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void unary_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
|
||||
out[index] = Op{}(in[idx]);
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto stride_x = strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto idx =
|
||||
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
@@ -144,7 +127,8 @@ void unary_op_gpu_inplace(
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
if (contig) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(OutType);
|
||||
// TODO: Choose optimized value based on type size.
|
||||
constexpr int N_READS = 4;
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out.data_size(), out.shape(), out.strides(), large, N_READS);
|
||||
encoder.add_kernel_node(
|
||||
@@ -158,30 +142,18 @@ void unary_op_gpu_inplace(
|
||||
} else {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
cu::unary_g<Op, InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
rest,
|
||||
out.data_size(),
|
||||
const_param(shape),
|
||||
const_param(strides),
|
||||
ndim);
|
||||
shape.size());
|
||||
}
|
||||
});
|
||||
} else {
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu)
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Abs)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ArcCos)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ArcCosh)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ArcSin)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ArcSinh)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ArcTan)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ArcTanh)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(BitwiseInvert)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Ceil)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Conjugate)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Cos)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Cosh)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Erf)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(ErfInv)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Exp)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Expm1)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Floor)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Imag)
|
||||
} // namespace mlx::core
|
||||
@@ -1,21 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Log::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu<cu::Log>(inputs, out, name(), s);
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Log1p)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(LogicalNot)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Negative)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Real)
|
||||
} // namespace mlx::core
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Round::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu<cu::Round>(inputs, out, name(), s);
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Sigmoid)
|
||||
} // namespace mlx::core
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/unary/unary.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
UNARY_GPU(Sign)
|
||||
} // namespace mlx::core
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user