Compare commits

..

6 Commits

Author SHA1 Message Date
Awni Hannun
193cdcd81a Fix graph updating (#2857)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-02 17:12:24 -08:00
Awni Hannun
d8ceae7b77 Reduce JVP (#2854) 2025-12-02 16:17:47 -08:00
Awni Hannun
eff0e31f00 Fix export scatters (#2852)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 11:24:40 -08:00
Awni Hannun
6c5785bc2f use thread local cpature mode (#2850)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-01 19:02:47 -08:00
CCYeh
8879ee00eb Support more Numpy interfaces for masked_scatter (#2832) 2025-12-01 17:51:02 -08:00
Cheng
6e762fe2e2 [CUDA] Migrate conv code to new cuDNN APIs (#2847)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 07:55:43 +09:00
18 changed files with 505 additions and 724 deletions

View File

@@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. No mask
broadcasting occurs.
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell

View File

@@ -15,19 +15,16 @@ namespace mlx::core {
namespace {
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
enum ConvBackendType {
CONV_FALLBACK,
CONV_FORWARD,
CONV_BACKWARD_INPUT,
CONV_BACKWARD_WEIGHT,
};
struct ConvCacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
fe::DataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
auto& conv_cache() {
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
std::pair<ConvBackendType, std::optional<DnnGraph>>>
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache;
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
auto get_conv_settings(
ConvBackendType backend_type,
array& x,
array& w,
array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
}
}
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
std::optional<DnnGraph> build_conv_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
Dtype dtype,
array& x,
array& w,
array& y,
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation) {
auto compute_dtype =
(dtype == float16 || dtype == bfloat16) ? float32 : dtype;
DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
auto x_ = graph.tensor_nchw("X", 'x', x);
auto w_ = graph.tensor_nchw("W", 'w', w);
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_cudnn_tensor_nchw('x', x))
.setwDesc(build_cudnn_tensor_nchw('w', w))
.setyDesc(build_cudnn_tensor_nchw('y', y))
.setcDesc(conv_desc)
.build();
auto set_options = [&](auto& options) {
options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
.set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
.set_stride(stride)
.set_pre_padding(padding_lo)
.set_post_padding(padding_hi)
.set_dilation(dilation);
};
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
std::shared_ptr<fe::graph::Tensor_attributes> y_;
if (backend_type == CONV_FORWARD) {
auto options = fe::graph::Conv_fprop_attributes();
set_options(options);
y_ = graph.conv_fprop(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_INPUT) {
auto options = fe::graph::Conv_dgrad_attributes();
set_options(options);
y_ = graph.conv_dgrad(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
auto options = fe::graph::Conv_wgrad_attributes();
set_options(options);
y_ = graph.conv_wgrad(w_, x_, options);
}
graph.tensor_nchw(y_, 'y', y)->set_output(true);
if (graph.prepare().is_bad()) {
return std::nullopt;
}
graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
if (dtype == float32 && !env::enable_tf32()) {
graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
}
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
array in,
array wt,
array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
array& in,
array& wt,
array& intermediate_out,
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
if (plan) {
// Run cached plan.
auto& [backend_type, graph] = it->second;
if (graph) {
// Run cached graph.
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder,
{
{'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
} else {
// Run fallback kernel.
gemm_conv(
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
SmallVector<ConvBackendType, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
}
// Try to build op graph.
cudnnBackendDescriptorType_t backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph;
ConvBackendType backend_type;
std::optional<DnnGraph> graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
auto [x, w, y] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
try_backend,
x,
w,
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_conv_op_graph(
graph = build_conv_graph(
encoder,
try_backend,
dtype,
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_lo,
padding_hi,
dilation);
if (op_graph) {
if (graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
in = std::move(x);
wt = std::move(w);
out = std::move(y);
break;
}
}
if (op_graph) {
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
if (graph) {
register_args(encoder, backend_type, in, wt, out, out_);
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder,
{
{'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*graph)));
return;
}
// Use fallback kernel for settings not supported by cuDNN.

View File

@@ -5,6 +5,7 @@
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
namespace mlx::core {
@@ -12,10 +13,12 @@ namespace mlx::core {
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
void check_cudnn_error(const char* name, cudnnStatus_t err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>

View File

@@ -7,32 +7,26 @@ namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
#define RETURN_IF_ERROR(cmd) \
if (auto ret = cmd; ret.is_bad()) { \
return ret; \
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// shape[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
inline auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
auto strides = normalized_strides(x);
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
@@ -51,226 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
return std::make_tuple(std::move(shape), std::move(strides));
}
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plans
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
fe::error_t DnnGraph::prepare() {
RETURN_IF_ERROR(validate());
try {
RETURN_IF_ERROR(build_operation_graph(handle_));
} catch (cudnn_frontend::cudnnException& error) {
// cuDNN bug: they did not catch all exceptions in the API.
return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
}
RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
return {};
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
fe::error_t DnnGraph::build() {
RETURN_IF_ERROR(check_support(handle_));
RETURN_IF_ERROR(build_plans(handle_));
return {};
}
fe::error_t DnnGraph::encode_graph(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack) {
cudnnSetStream(handle_, encoder.stream());
CudaGraph cuda_graph(encoder.device());
RETURN_IF_ERROR(populate_cuda_graph(
handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
encoder.add_graph_node(cuda_graph);
return {};
}
fe::error_t DnnGraph::encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack) {
auto* workspace_ptr = prepare_workspace(encoder);
auto capture = encoder.capture_context();
cudnnSetStream(handle_, encoder.stream());
auto ret = execute(handle_, variant_pack, workspace_ptr);
if (ret.is_bad()) {
capture.discard = true;
}
return ret;
}
void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
return gpu_ptr<void>(workspace);
}
return nullptr;
}
void DnnGraph::set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& strides) {
tensor->set_uid(uid)
.set_alignment(get_alignment(x))
.set_data_type(dtype_to_cudnn_type(x.dtype()))
.set_dim(shape)
.set_stride(strides);
}
void DnnGraph::set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
set_tensor_attrs(
tensor,
uid,
x,
convert_vector<int64_t>(x.shape()),
normalized_strides(x));
}
void DnnGraph::set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
set_tensor_attrs(tensor, uid, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core

View File

@@ -2,25 +2,30 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
namespace fe = cudnn_frontend;
#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
@@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
inline std::vector<T> convert_vector(const Vec& vec) {
return std::vector<T>(vec.begin(), vec.end());
}
// Map dtype to cudnn data type.
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
}
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
@@ -55,111 +83,89 @@ inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
// Extends cuDNN graph with helpers.
class DnnGraph : public fe::graph::Graph {
public:
DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
: handle_(handle) {
set_io_data_type(dtype_to_cudnn_type(io_dtype));
set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a cuDNN tensor description from MLX array |x|.
auto& tensor(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs(attrs, uid, x);
return attrs;
}
auto tensor(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a cuDNN tensor description from MLX array |x|, and transpose it from
// NHWC layout to NCHW.
auto& tensor_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs_nchw(attrs, uid, x);
return attrs;
}
auto tensor_nchw(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor_nchw(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a cuDNN tensor for scalar.
auto scalar(const char* name, int64_t uid, Dtype dtype) {
return Graph::tensor(fe::graph::Tensor_attributes()
.set_name(name)
.set_uid(uid)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(dtype_to_cudnn_type(dtype)));
}
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Call this before setting notes.
fe::error_t prepare();
// Call this after setting notes.
fe::error_t build();
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Add cuDNN graph to CUDA graph, using native CUDA graph API.
fe::error_t encode_graph(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Add cuDNN graph to CUDA graph, using stream capture.
fe::error_t encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
private:
void* prepare_workspace(cu::CommandEncoder& encoder);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& strides);
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
void set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
cudnnHandle_t handle_;
};
} // namespace mlx::core

View File

@@ -14,15 +14,6 @@ namespace mlx::core::cu {
namespace {
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
}
bool use_cuda_graphs() {
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
return use_graphs;
@@ -96,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
return;
}
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
@@ -341,21 +332,30 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
for (const auto& node : nodes) {
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type != cudaGraphNodeTypeKernel) {
if (type == cudaGraphNodeTypeGraph) {
// Try to be updatable for a structure like graph -> graph -> kernel
if (num_nodes > 1) {
return false;
}
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
return is_graph_updatable(child, cluster_dim_x);
} else if (type != cudaGraphNodeTypeKernel) {
return false;
} else {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
return true;
}
@@ -371,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
}
cudaGraphNode_t node;
int cluster_dim_x = 0;
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});

View File

@@ -10,46 +10,8 @@
namespace mlx::core {
namespace fe = cudnn_frontend;
namespace {
#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)
std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
tensor->set_uid(uid)
.set_dim({x.shape().begin(), x.shape().end()})
.set_stride(normalized_strides(x));
}
array prepare_sdpa_input(const array& x, Stream s) {
// SDPA kernel's requirements on inputs:
// 1. last dim's stride be 1;
@@ -99,7 +61,7 @@ constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
fe::DataType_t cudnn_dtype;
std::array<int, QKV_NDIM> q_shape;
std::array<int, QKV_NDIM> k_shape;
std::array<int, QKV_NDIM> v_shape;
@@ -143,13 +105,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
@@ -169,7 +131,7 @@ enum UIDS {
D_O,
};
fe::graph::Graph build_sdpa_graph(
DnnGraph build_sdpa_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
@@ -179,34 +141,15 @@ fe::graph::Graph build_sdpa_graph(
bool output_logsumexp,
const array& o,
const array& stats) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
DnnGraph graph(handle, q.dtype());
fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto q_ = graph.tensor("Q", Q, q);
auto k_ = graph.tensor("K", K, k);
auto v_ = graph.tensor("V", V, v);
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_attn_scale(graph.scalar("Scale", SCALE, float32))
.set_generate_stats(output_logsumexp);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
@@ -216,31 +159,23 @@ fe::graph::Graph build_sdpa_graph(
}
}
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
}
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
set_tensor_attrs(o_, O, o);
graph.tensor(o_, O, o)->set_output(true);
if (output_logsumexp) {
stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
set_tensor_attrs(stats_, STATS, stats);
graph.tensor(stats_, STATS, stats)->set_output(true);
}
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
CHECK_CUDNN_FE_ERROR(graph.prepare());
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
fe::graph::Graph build_sdpa_backward_graph(
DnnGraph build_sdpa_backward_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
@@ -253,42 +188,18 @@ fe::graph::Graph build_sdpa_backward_graph(
array& d_q,
array& d_k,
array& d_v) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
DnnGraph graph(handle, q.dtype());
fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
set_tensor_attrs(o_, O, o);
set_tensor_attrs(d_o_, D_O, d_o);
set_tensor_attrs(stats_, STATS, stats);
stats_->set_data_type(fe::DataType_t::FLOAT);
auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto q_ = graph.tensor("Q", Q, q);
auto k_ = graph.tensor("K", K, k);
auto v_ = graph.tensor("V", V, v);
auto o_ = graph.tensor("O", O, o);
auto d_o_ = graph.tensor("D_O", D_O, d_o);
auto stats_ = graph.tensor("STATS", STATS, stats);
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_attn_scale(scale);
.set_attn_scale(graph.scalar("Scale", SCALE, float32));
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
@@ -297,56 +208,22 @@ fe::graph::Graph build_sdpa_backward_graph(
}
}
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
}
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
d_q_->set_output(true);
d_k_->set_output(true);
d_v_->set_output(true);
set_tensor_attrs(d_q_, D_Q, d_q);
set_tensor_attrs(d_k_, D_K, d_k);
set_tensor_attrs(d_v_, D_V, d_v);
graph.tensor(d_q_, D_Q, d_q)->set_output(true);
graph.tensor(d_k_, D_K, d_k)->set_output(true);
graph.tensor(d_v_, D_V, d_v)->set_output(true);
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
CHECK_CUDNN_FE_ERROR(graph.prepare());
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
void execute_graph(
cu::CommandEncoder& encoder,
cudnnHandle_t handle,
fe::graph::Graph& graph,
std::unordered_map<int64_t, void*>& variant_pack) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
cudnnSetStream(handle, encoder.stream());
CudaGraph cuda_graph(encoder.device());
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
handle, variant_pack, workspace_ptr, cuda_graph));
encoder.add_graph_node(cuda_graph);
}
} // namespace
bool supports_sdpa_cudnn(
@@ -420,19 +297,19 @@ void sdpa_cudnn(
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{Q, gpu_ptr<void>(q)},
{K, gpu_ptr<void>(k)},
{V, gpu_ptr<void>(v)},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
}
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
execute_graph(encoder, handle, graph, variant_pack);
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
}
void sdpa_backward_cudnn(
@@ -480,21 +357,21 @@ void sdpa_backward_cudnn(
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{Q, gpu_ptr<void>(q)},
{K, gpu_ptr<void>(k)},
{V, gpu_ptr<void>(v)},
{SCALE, &scale},
{O, const_cast<void*>(gpu_ptr<void>(o))},
{STATS, const_cast<void*>(gpu_ptr<void>(stats))},
{D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
{O, gpu_ptr<void>(o)},
{STATS, gpu_ptr<void>(stats)},
{D_O, gpu_ptr<void>(d_o)},
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
}
execute_graph(encoder, handle, graph, variant_pack);
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
}
// Defined in scaled_dot_product_attention.cu file.

View File

@@ -32,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) {
}
}
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
}
const char* dtype_to_cuda_type(const Dtype& dtype) {
switch (dtype) {
case bool_:

View File

@@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) {
arr.offset());
}
// For const array, keep constness in pointer unless it is untyped.
template <typename T>
inline const T* gpu_ptr(const array& arr) {
inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
const array& arr) {
return gpu_ptr<T>(const_cast<array&>(arr));
}

View File

@@ -16,77 +16,7 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
CCC="xcrun -sdk macosx metal -x metal"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
declare -a HDRS_LIST=($HDRS)
declare -a HDRS_STACK=()
declare -a HDRS_SORTED=()
length=${#HDRS_LIST[@]}
HDRS_LIST+=(".")
for ((i=0; i<${length}; i+=2));
do
header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
str_this="${HDRS_LIST[$i]}"
str_next="${HDRS_LIST[$i + 2]}"
depth_this=${#str_this}
depth_next=${#str_next}
# If we have a dependency then we stack it
if [ $depth_next -gt $depth_this ]; then
HDRS_STACK=($header ${HDRS_STACK[@]})
# If we are done with this level
else
# We add the header to out list
HDRS_SORTED+=($header)
# Pop the stacked up dependencies
pop_len=$((depth_this - depth_next))
for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
do
HDRS_SORTED+=($popped_header)
done
HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
fi
done
HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
CONTENT=$(
echo "// Copyright © 2025 Apple Inc."
echo ""
echo "// Auto generated source for $INPUT_FILE"
echo ""
for header in "${HDRS_SORTED[@]}"
do
echo "///////////////////////////////////////////////////////////////////////////////"
echo "// Contents from \"${header}\""
echo "///////////////////////////////////////////////////////////////////////////////"
echo ""
echo "#line 1 \"${header}\""
grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
echo ""
done
echo "///////////////////////////////////////////////////////////////////////////////"
)
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
namespace mlx::core::metal {

View File

@@ -382,6 +382,7 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(LogicalOr),
SERIALIZE_PRIMITIVE(LogAddExp),
SERIALIZE_PRIMITIVE(LogSumExp),
SERIALIZE_PRIMITIVE(MaskedScatter),
SERIALIZE_PRIMITIVE(Matmul),
SERIALIZE_PRIMITIVE(Maximum),
SERIALIZE_PRIMITIVE(Minimum),

View File

@@ -3466,10 +3466,8 @@ array masked_scatter(
if (mask.dtype() != bool_) {
throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
}
if (mask.ndim() == 0) {
throw std::invalid_argument(
"[masked_scatter] Scalar masks are not supported.");
} else if (mask.ndim() > a.ndim()) {
if (mask.ndim() > a.ndim()) {
throw std::invalid_argument(
"[masked_scatter] The mask cannot have more dimensions than the target.");
}

View File

@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
}
}
std::vector<array> Reduce::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto in = primals[0];
auto s = stream();
auto grad_op = [&s, reduce_type = reduce_type_](
const array& x, const array& tan, int axis) {
if (reduce_type == Reduce::Min) {
auto idx = argmin(x, axis, true, s);
return take_along_axis(tan, idx, axis, s);
} else if (reduce_type == Reduce::Max) {
auto idx = argmax(x, axis, true, s);
return take_along_axis(tan, idx, axis, s);
} else {
auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
auto out = multiply(multiply(p1, p2, s), tan, s);
return sum(out, axis, true, s);
}
};
auto tan = tangents[0];
if (reduce_type_ == Reduce::Sum) {
return {sum(tan, axes_, true, s)};
} else {
if (axes_.size() > 1) {
std::vector<int> transpose_to;
{
// Find the transpose needed to move axes_ to the back.
int j = 0;
for (int i = 0; i < in.ndim(); i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
transpose_to.push_back(i);
}
}
for (auto ax : axes_) {
transpose_to.push_back(ax);
}
}
int start_ax = in.ndim() - axes_.size();
in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
auto grad = squeeze(grad_op(in, tan, -1), -1, s);
return {expand_dims(grad, axes_, s)};
} else {
return {grad_op(in, tan, axes_[0])};
}
}
}
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_GRADS();
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@@ -1871,13 +1866,13 @@ class Scatter : public UnaryPrimitive {
const char* name() const override {
switch (reduce_type_) {
case Sum:
return "ScatterSum";
return "Scatter Sum";
case Prod:
return "ScatterProd";
return "Scatter Prod";
case Min:
return "ScatterMin";
return "Scatter Min";
case Max:
return "ScatterMax";
return "Scatter Max";
case None:
return "Scatter";
}
@@ -1910,7 +1905,7 @@ class ScatterAxis : public UnaryPrimitive {
const char* name() const override {
switch (reduce_type_) {
case Sum:
return "ScatterAxisSum";
return "ScatterAxis Sum";
case None:
return "ScatterAxis";
}

View File

@@ -766,7 +766,7 @@ auto mlx_slice_update(
const nb::object& obj,
const ScalarOrArray& v) {
// Can't route to slice update if not slice, tuple, or int
if (src.ndim() == 0 ||
if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
!nb::isinstance<nb::int_>(obj))) {
return std::make_pair(false, src);
@@ -888,7 +888,9 @@ auto mlx_slice_update(
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
if (nb::isinstance<mx::array>(obj)) {
if (nb::isinstance<nb::bool_>(obj)) {
return mx::array(nb::cast<bool>(obj), mx::bool_);
} else if (nb::isinstance<mx::array>(obj)) {
auto mask = nb::cast<mx::array>(obj);
if (mask.dtype() == mx::bool_) {
return mask;
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
if (mask.dtype() == nb::dtype<bool>()) {
return nd_array_to_mlx(mask, mx::bool_);
}
} else if (nb::isinstance<nb::list>(obj)) {
auto mask = array_from_list(nb::cast<nb::list>(obj), {});
if (mask.dtype() == mx::bool_) {
return mask;
}
}
return std::nullopt;
}

View File

@@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(a, anp))
def test_setitem_with_boolean_mask(self):
mask_np = np.zeros((10, 10), dtype=bool)
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
# Python list mask
a = mx.array([1.0, 2.0, 3.0])
mask = [True, False, True]
src = mx.array([5.0, 6.0])
expected = mx.array([5.0, 2.0, 6.0])
a[mask] = src
self.assertTrue(mx.array_equal(a, expected))
# mx.array scalar mask
a = mx.array([1.0, 2.0, 3.0])
mask = mx.array(True)
expected = mx.array([5.0, 5.0, 5.0])
a[mask] = 5.0
self.assertTrue(mx.array_equal(a, expected))
# scalar mask
a = mx.array([1.0, 2.0, 3.0])
mask = True
expected = mx.array([5.0, 5.0, 5.0])
a[mask] = 5.0
self.assertTrue(mx.array_equal(a, expected))
mask_np = np.zeros((1, 10, 10), dtype=bool)
with self.assertRaises(ValueError):

View File

@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)
def test_reduce_jvp(self):
a = mx.arange(4)
b = mx.array([3, 2, 1, 0])
out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 6)
out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 18)
out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 3)
out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 0)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()

View File

@@ -596,6 +596,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
for y in ys:
self.assertEqual(imported(y)[0].item(), fun(y).item())
def test_export_import_scatter_sum(self):
def fun(x, y, z):
return x.at[y].add(z)
x = mx.array([1, 2, 3])
y = mx.array([0, 0, 1])
z = mx.array([1, 1, 1])
path = os.path.join(self.test_dir, "fn.mlxfn")
mx.export_function(path, fun, x, y, z)
imported = mx.import_function(path)
self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()