Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: 6 commits on jit-nax ... 193cdcd81a
| Author | SHA1 | Date |
|---|---|---|
| | 193cdcd81a | |
| | d8ceae7b77 | |
| | eff0e31f00 | |
| | 6c5785bc2f | |
| | 8879ee00eb | |
| | 6e762fe2e2 | |
@@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are
 
 Boolean masks follow NumPy semantics:
 
-- The mask shape must match the shape of the axes it indexes exactly. No mask
-  broadcasting occurs.
+- The mask shape must match the shape of the axes it indexes exactly. The only
+  exception is a scalar boolean mask, which broadcasts to the full array.
 - Any axes not covered by the mask are taken in full.
 
 .. code-block:: shell
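For reference, a short sketch of the semantics listed above (array shapes and values here are illustrative, not taken from the diff):

.. code-block:: python

   import mlx.core as mx

   a = mx.zeros((2, 3, 4))
   mask = mx.array([[True, False, True], [False, True, False]])  # shape (2, 3)
   a[mask] = 1.0            # matches the first two axes exactly; the last axis is taken in full
   a[mx.array(True)] = 2.0  # a scalar boolean mask broadcasts to the full array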
@@ -15,19 +15,16 @@ namespace mlx::core {
 
 namespace {
 
-// Alias for better readability.
-#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
-#define CONV_BACKWARD_INPUT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
-#define CONV_BACKWARD_WEIGHT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
-
-// Custom placeholder representing fallback kernel.
-#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
+enum ConvBackendType {
+  CONV_FALLBACK,
+  CONV_FORWARD,
+  CONV_BACKWARD_INPUT,
+  CONV_BACKWARD_WEIGHT,
+};
 
 struct ConvCacheKey {
   int device_id;
-  cudnnDataType_t cudnn_dtype;
+  fe::DataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
   std::array<int, MAX_NDIM> weight_shape;
   std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
 auto& conv_cache() {
   static LRUBytesKeyCache<
       ConvCacheKey,
-      std::pair<
-          cudnnBackendDescriptorType_t,
-          std::optional<cudnn_frontend::ExecutionPlan>>>
+      std::pair<ConvBackendType, std::optional<DnnGraph>>>
       cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
   return cache;
 }
 
-auto get_conv_op_settings(
-    cudnnBackendDescriptorType_t backend_type,
+auto get_conv_settings(
+    ConvBackendType backend_type,
     array& x,
     array& w,
     array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
   for (int i = 0; i < padding_lo.size(); ++i) {
     int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
     padding_lo[i] = wt_size - padding_lo[i] - 1;
-    int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
-    int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+    int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
+    int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
     padding_hi[i] = out_size - in_size + padding_hi[i];
   }
   return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
   }
 }
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
|
||||
std::optional<DnnGraph> build_conv_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
ConvBackendType backend_type,
|
||||
Dtype dtype,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const SmallVector<int64_t>& stride,
|
||||
const SmallVector<int64_t>& padding_lo,
|
||||
const SmallVector<int64_t>& padding_hi,
|
||||
const SmallVector<int64_t>& dilation) {
|
||||
try {
|
||||
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
: dtype_to_cudnn_type(dtype);
|
||||
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||
.setDataType(compute_dtype)
|
||||
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||
.setNDims(stride.size())
|
||||
.setStrides(stride.size(), stride.data())
|
||||
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||
.setDilation(dilation.size(), dilation.data())
|
||||
.build();
|
||||
const std::vector<int64_t>& stride,
|
||||
const std::vector<int64_t>& padding_lo,
|
||||
const std::vector<int64_t>& padding_hi,
|
||||
const std::vector<int64_t>& dilation) {
|
||||
auto compute_dtype =
|
||||
(dtype == float16 || dtype == bfloat16) ? float32 : dtype;
|
||||
DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
|
||||
auto x_ = graph.tensor_nchw("X", 'x', x);
|
||||
auto w_ = graph.tensor_nchw("W", 'w', w);
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_cudnn_tensor_nchw('x', x))
|
||||
.setwDesc(build_cudnn_tensor_nchw('w', w))
|
||||
.setyDesc(build_cudnn_tensor_nchw('y', y))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
auto set_options = [&](auto& options) {
|
||||
options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
|
||||
.set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
|
||||
.set_stride(stride)
|
||||
.set_pre_padding(padding_lo)
|
||||
.set_post_padding(padding_hi)
|
||||
.set_dilation(dilation);
|
||||
};
|
||||
|
||||
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||
return cudnn_frontend::OperationGraphBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setOperationGraph(ops.size(), ops.data())
|
||||
.build();
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
|
||||
throw;
|
||||
}
|
||||
std::shared_ptr<fe::graph::Tensor_attributes> y_;
|
||||
if (backend_type == CONV_FORWARD) {
|
||||
auto options = fe::graph::Conv_fprop_attributes();
|
||||
set_options(options);
|
||||
y_ = graph.conv_fprop(x_, w_, options);
|
||||
} else if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
auto options = fe::graph::Conv_dgrad_attributes();
|
||||
set_options(options);
|
||||
y_ = graph.conv_dgrad(x_, w_, options);
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
auto options = fe::graph::Conv_wgrad_attributes();
|
||||
set_options(options);
|
||||
y_ = graph.conv_wgrad(w_, x_, options);
|
||||
}
|
||||
graph.tensor_nchw(y_, 'y', y)->set_output(true);
|
||||
|
||||
if (graph.prepare().is_bad()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
|
||||
if (dtype == float32 && !env::enable_tf32()) {
|
||||
graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
|
||||
}
|
||||
CHECK_CUDNN_FE_ERROR(graph.build());
|
||||
return graph;
|
||||
}
|
||||
|
||||
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
|
||||
@@ -181,7 +184,7 @@ array group_transpose(
|
||||
// eval_gpu, with cost of possible redundant copies.
|
||||
std::tuple<array, array, array> prepare_args(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
ConvBackendType backend_type,
|
||||
array in,
|
||||
array wt,
|
||||
array out,
|
||||
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
|
||||
return {std::move(in), std::move(wt), std::move(out)};
|
||||
}
|
||||
|
||||
// Get the x/w/y args from the in/wt/out args depending on backend type.
|
||||
inline std::tuple<array&, array&, array&> dispatch_args(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& in,
|
||||
array& wt,
|
||||
array& out) {
|
||||
switch (backend_type) {
|
||||
case CONV_BACKWARD_INPUT:
|
||||
return {out, wt, in};
|
||||
case CONV_BACKWARD_WEIGHT:
|
||||
return {in, out, wt};
|
||||
default:
|
||||
return {in, wt, out};
|
||||
}
|
||||
}
|
||||
|
||||
// Register inputs and outputs before actually running conv op. Can only be
|
||||
// called once per eval_gpu.
|
||||
void register_args(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
ConvBackendType backend_type,
|
||||
array& in,
|
||||
array& wt,
|
||||
array& intermediate_out,
|
||||
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
get_alignment(wt),
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
auto& [backend_type, plan] = it->second;
|
||||
if (plan) {
|
||||
// Run cached plan.
|
||||
auto& [backend_type, graph] = it->second;
|
||||
if (graph) {
|
||||
// Run cached graph.
|
||||
std::tie(in, wt, out) =
|
||||
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
}
|
||||
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
|
||||
encoder,
|
||||
{
|
||||
{'x', gpu_ptr<void>(in)},
|
||||
{'w', gpu_ptr<void>(wt)},
|
||||
{'y', gpu_ptr<void>(out)},
|
||||
}));
|
||||
} else {
|
||||
// Run fallback kernel.
|
||||
gemm_conv(
|
||||
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
|
||||
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||
// convolution, so we make a best guess and then try.
|
||||
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
|
||||
SmallVector<ConvBackendType, 2> try_backends;
|
||||
if (flip_) {
|
||||
// When weight is flipped, we assume it is backward input convolution.
|
||||
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
}
|
||||
|
||||
// Try to build op graph.
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||
ConvBackendType backend_type;
|
||||
std::optional<DnnGraph> graph;
|
||||
for (auto try_backend : try_backends) {
|
||||
auto [in_copy, wt_copy, out_copy] =
|
||||
auto [x, w, y] =
|
||||
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
|
||||
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
|
||||
try_backend,
|
||||
x,
|
||||
w,
|
||||
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_);
|
||||
op_graph = build_conv_op_graph(
|
||||
graph = build_conv_graph(
|
||||
encoder,
|
||||
try_backend,
|
||||
dtype,
|
||||
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
dilation);
|
||||
if (op_graph) {
|
||||
if (graph) {
|
||||
backend_type = try_backend;
|
||||
in = std::move(in_copy);
|
||||
wt = std::move(wt_copy);
|
||||
out = std::move(out_copy);
|
||||
in = std::move(x);
|
||||
wt = std::move(w);
|
||||
out = std::move(y);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (op_graph) {
|
||||
// Find a plan for the graph and execute it.
|
||||
auto plan = find_cudnn_plan_from_op_graph(
|
||||
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
||||
if (plan) {
|
||||
// Setup inputs and outputs.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (graph) {
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
|
||||
encoder,
|
||||
{
|
||||
{'x', gpu_ptr<void>(in)},
|
||||
{'w', gpu_ptr<void>(wt)},
|
||||
{'y', gpu_ptr<void>(out)},
|
||||
}));
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(*graph)));
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fallback kernel for settings not supported by cuDNN.
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cudnn.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -12,10 +13,12 @@ namespace mlx::core {
|
||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
void check_cuda_error(const char* name, CUresult err);
|
||||
void check_cudnn_error(const char* name, cudnnStatus_t err);
|
||||
|
||||
// The macro version that prints the command that failed.
|
||||
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
|
||||
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
|
||||
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||
|
||||
// Base class for RAII managed CUDA resources.
|
||||
template <typename Handle, cudaError_t (*Destroy)(Handle)>
|
||||
|
||||
@@ -7,32 +7,26 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Create a cudnn tensor descriptor.
|
||||
template <typename Vec>
|
||||
inline cudnn_frontend::Tensor build_cudnn_tensor(
|
||||
int64_t id,
|
||||
const array& x,
|
||||
const Vec& shape,
|
||||
const Vec& strides) {
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
+#define RETURN_IF_ERROR(cmd)          \
+  if (auto ret = cmd; ret.is_bad()) { \
+    return ret;                       \
+  }
 
 // In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
 // whether a tensor is contiguous is determined with:
 //   strides[dim] == shape[dim + 1] * strides[dim + 1]
 // So a contiguous array with singleton dims in MLX may be mistakenly treated
 // as strided in cuDNN, and we work around it by normalizing the strides.
-Strides normalized_strides(const array& x) {
-  if (!x.flags().row_contiguous || x.ndim() < 2) {
-    return x.strides();
+std::vector<int64_t> normalized_strides(const array& x) {
+  std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
+  if (std::all_of(
+          strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
+    strides.back() = 1;
+    return strides;
+  }
+  if (!x.flags().row_contiguous || x.ndim() < 2) {
+    return strides;
   }
-  Strides strides = x.strides();
   for (int i = x.ndim() - 2; i >= 0; --i) {
     if (x.shape(i) == 1) {
       strides[i] = x.shape(i + 1) * strides[i + 1];
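To illustrate the normalization rule described in the comment above, here is a small pure-Python sketch (not MLX API; the shapes and strides used are hypothetical):

    def normalized_strides(shape, strides, row_contiguous):
        # All-zero strides (fully broadcast input): give the last dim stride 1.
        if all(s == 0 for s in strides):
            return list(strides[:-1]) + [1]
        if not row_contiguous or len(shape) < 2:
            return list(strides)
        strides = list(strides)
        # Rewrite strides of singleton dims so cuDNN's contiguity check
        # strides[d] == shape[d + 1] * strides[d + 1] still holds.
        for d in range(len(shape) - 2, -1, -1):
            if shape[d] == 1:
                strides[d] = shape[d + 1] * strides[d + 1]
        return strides

    # A row-contiguous (1, 4) array may carry stride 1 (or 0) on the singleton
    # dim; normalization rewrites it to 4 so cuDNN still sees it as contiguous.
    print(normalized_strides([1, 4], [1, 1], True))  # [4, 1]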
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
|
||||
}
|
||||
|
||||
// Return the shape and strides after transposing from NHWC to NCHW.
|
||||
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
||||
inline auto nhwc_to_nchw(const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
auto strides = normalized_strides(x);
|
||||
assert(shape.size() >= 3);
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
@@ -51,226 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline auto nhwc_to_nchw(const array& x) {
|
||||
return nhwc_to_nchw(
|
||||
convert_vector<int64_t>(x.shape()), normalized_strides(x));
|
||||
}
|
||||
|
||||
// Return available engines for a |op_graph|.
|
||||
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = true) {
|
||||
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
|
||||
sources.push_back([](auto& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
});
|
||||
if (use_fallback) {
|
||||
sources.push_back([&backend_type](auto& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
});
|
||||
}
|
||||
|
||||
auto configs =
|
||||
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
|
||||
.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
// Take |engine_configs| and |op_graph| and find a working execution plans
|
||||
// from them.
|
||||
std::optional<cudnn_frontend::ExecutionPlan>
|
||||
find_cudnn_plan_from_engine_configs(
|
||||
cudnnHandle_t handle,
|
||||
const cudnn_frontend::EngineConfigList& engine_configs,
|
||||
const cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto op_graph_tag = op_graph.getTag();
|
||||
for (const auto& config : engine_configs) {
|
||||
try {
|
||||
return cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(handle)
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Prepare workspace and args to execute plan.
|
||||
template <typename F>
|
||||
bool prepare_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs,
|
||||
F&& execute) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
void* workspace_ptr = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
auto args = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace_ptr)
|
||||
.setDataPointers(num_args, data_ptrs)
|
||||
.setUids(num_args, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
|
||||
fe::error_t DnnGraph::prepare() {
|
||||
RETURN_IF_ERROR(validate());
|
||||
try {
|
||||
RETURN_IF_ERROR(build_operation_graph(handle_));
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
// cuDNN bug: they did not catch all exceptions in the API.
|
||||
return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
|
||||
}
|
||||
RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
|
||||
return {};
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
|
||||
fe::error_t DnnGraph::build() {
|
||||
RETURN_IF_ERROR(check_support(handle_));
|
||||
RETURN_IF_ERROR(build_plans(handle_));
|
||||
return {};
|
||||
}
|
||||
|
||||
fe::error_t DnnGraph::encode_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
std::unordered_map<int64_t, void*> variant_pack) {
|
||||
cudnnSetStream(handle_, encoder.stream());
|
||||
CudaGraph cuda_graph(encoder.device());
|
||||
RETURN_IF_ERROR(populate_cuda_graph(
|
||||
handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
|
||||
encoder.add_graph_node(cuda_graph);
|
||||
return {};
|
||||
}
|
||||
|
||||
fe::error_t DnnGraph::encode_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
std::unordered_map<int64_t, void*> variant_pack) {
|
||||
auto* workspace_ptr = prepare_workspace(encoder);
|
||||
auto capture = encoder.capture_context();
|
||||
cudnnSetStream(handle_, encoder.stream());
|
||||
auto ret = execute(handle_, variant_pack, workspace_ptr);
|
||||
if (ret.is_bad()) {
|
||||
capture.discard = true;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
|
||||
int64_t workspace_size = 0;
|
||||
CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
cu::malloc_async(workspace_size, encoder),
|
||||
{static_cast<int>(workspace_size)},
|
||||
uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
return gpu_ptr<void>(workspace);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void DnnGraph::set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x,
|
||||
const std::vector<int64_t>& shape,
|
||||
const std::vector<int64_t>& strides) {
|
||||
tensor->set_uid(uid)
|
||||
.set_alignment(get_alignment(x))
|
||||
.set_data_type(dtype_to_cudnn_type(x.dtype()))
|
||||
.set_dim(shape)
|
||||
.set_stride(strides);
|
||||
}
|
||||
|
||||
void DnnGraph::set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
set_tensor_attrs(
|
||||
tensor,
|
||||
uid,
|
||||
x,
|
||||
convert_vector<int64_t>(x.shape()),
|
||||
normalized_strides(x));
|
||||
}
|
||||
|
||||
void DnnGraph::set_tensor_attrs_nchw(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
set_tensor_attrs(tensor, uid, x, shape, strides);
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
|
||||
if (x.ndim() == 0) {
|
||||
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
|
||||
}
|
||||
if (x.ndim() == 1) {
|
||||
int64_t s = x.shape(0);
|
||||
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
|
||||
SmallVector<int64_t, 4> strides = {s, 1, s, s};
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
if (x.ndim() == 2) {
|
||||
int64_t s =
|
||||
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
|
||||
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
|
||||
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
if (x.ndim() == 3 || x.ndim() == 4) {
|
||||
return build_cudnn_tensor_nchw(id, x);
|
||||
}
|
||||
throw std::runtime_error(
|
||||
fmt::format("Unsupported array with {} dims.", x.ndim()));
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
|
||||
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(scalar_dims.size(), scalar_dims.data())
|
||||
.setStrides(scalar_dims.size(), scalar_dims.data())
|
||||
.setId(id)
|
||||
.setAlignment(16)
|
||||
.setDataType(dtype_to_cudnn_type(dtype))
|
||||
.setByValue(true)
|
||||
.build();
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||
cudnnHandle_t handle,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
|
||||
if (engine_configs.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
|
||||
}
|
||||
|
||||
bool encode_cudnn_plan_with_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs) {
|
||||
return prepare_cudnn_plan(
|
||||
encoder,
|
||||
plan,
|
||||
num_args,
|
||||
uids,
|
||||
data_ptrs,
|
||||
[&](auto handle, auto plan, auto args) {
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
bool encode_cudnn_plan_with_graph_api(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs) {
|
||||
return prepare_cudnn_plan(
|
||||
encoder,
|
||||
plan,
|
||||
num_args,
|
||||
uids,
|
||||
data_ptrs,
|
||||
[&](auto handle, auto plan, auto args) {
|
||||
if (!graph) {
|
||||
graph = CudaGraph(encoder.device());
|
||||
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,25 +2,30 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class CommandEncoder;
|
||||
}
|
||||
|
||||
namespace fe = cudnn_frontend;
|
||||
|
||||
#define CHECK_CUDNN_FE_ERROR(cmd) \
|
||||
do { \
|
||||
auto error = cmd; \
|
||||
if (!error.is_good()) { \
|
||||
throw std::runtime_error( \
|
||||
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Return pointer alignment of |x|'s data.
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
@@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {
|
||||
|
||||
// Convert the type of elements in |vec| to |T|.
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
inline std::vector<T> convert_vector(const Vec& vec) {
|
||||
return std::vector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
// Map dtype to cudnn data type.
|
||||
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return fe::DataType_t::INT8;
|
||||
case int32:
|
||||
return fe::DataType_t::INT32;
|
||||
case uint8:
|
||||
return fe::DataType_t::UINT8;
|
||||
case float16:
|
||||
return fe::DataType_t::HALF;
|
||||
case bfloat16:
|
||||
return fe::DataType_t::BFLOAT16;
|
||||
case float32:
|
||||
return fe::DataType_t::FLOAT;
|
||||
case float64:
|
||||
return fe::DataType_t::DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
|
||||
@@ -55,111 +83,89 @@ inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helpers used by get_data_ptrs to get pointers.
|
||||
inline void* get_data_ptr(const array& arr) {
|
||||
return const_cast<void*>(gpu_ptr<void>(arr));
|
||||
}
|
||||
|
||||
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
|
||||
inline void* get_data_ptr(T& scalar) {
|
||||
return &scalar;
|
||||
}
|
||||
|
||||
// Return an array filled with data pointers of args.
|
||||
template <typename... Args>
|
||||
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
|
||||
return {get_data_ptr(args)...};
|
||||
}
|
||||
|
||||
// Map dtype to cudnn data type.
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
// Extends cuDNN graph with helpers.
|
||||
class DnnGraph : public fe::graph::Graph {
|
||||
public:
|
||||
DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
|
||||
: handle_(handle) {
|
||||
set_io_data_type(dtype_to_cudnn_type(io_dtype));
|
||||
set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
|
||||
set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
|
||||
}
|
||||
}
|
||||
|
||||
// Create a tensor descriptor from |x|.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
|
||||
// Create a cuDNN tensor description from MLX array |x|.
|
||||
auto& tensor(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
set_tensor_attrs(attrs, uid, x);
|
||||
return attrs;
|
||||
}
|
||||
auto tensor(const char* name, int64_t uid, const array& x) {
|
||||
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
|
||||
tensor(attrs, uid, x);
|
||||
return attrs;
|
||||
}
|
||||
|
||||
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
|
||||
// Create a cuDNN tensor description from MLX array |x|, and transpose it from
|
||||
// NHWC layout to NCHW.
|
||||
auto& tensor_nchw(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
set_tensor_attrs_nchw(attrs, uid, x);
|
||||
return attrs;
|
||||
}
|
||||
auto tensor_nchw(const char* name, int64_t uid, const array& x) {
|
||||
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
|
||||
tensor_nchw(attrs, uid, x);
|
||||
return attrs;
|
||||
}
|
||||
|
||||
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
|
||||
// from NHWC to NCHW.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
|
||||
// Create a cuDNN tensor for scalar.
|
||||
auto scalar(const char* name, int64_t uid, Dtype dtype) {
|
||||
return Graph::tensor(fe::graph::Tensor_attributes()
|
||||
.set_name(name)
|
||||
.set_uid(uid)
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
.set_data_type(dtype_to_cudnn_type(dtype)));
|
||||
}
|
||||
|
||||
// Create a 4D scalar tensor descriptor, which is passed by value.
|
||||
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
|
||||
// Call this before setting notes.
|
||||
fe::error_t prepare();
|
||||
// Call this after setting notes.
|
||||
fe::error_t build();
|
||||
|
||||
// Find a working plan for |op_graph|.
|
||||
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||
cudnnHandle_t handle,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph);
|
||||
// Add cuDNN graph to CUDA graph, using native CUDA graph API.
|
||||
fe::error_t encode_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
std::unordered_map<int64_t, void*> variant_pack);
|
||||
// Add cuDNN graph to CUDA graph, using stream capture.
|
||||
fe::error_t encode_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
std::unordered_map<int64_t, void*> variant_pack);
|
||||
|
||||
// Encode the plan to command buffer by capturing.
|
||||
bool encode_cudnn_plan_with_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs);
|
||||
private:
|
||||
void* prepare_workspace(cu::CommandEncoder& encoder);
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
// Encode the plan to command buffer by using native graph api of cudnn. If the
|
||||
// |graph| is empty it will be populated, otherwise it will be updated.
|
||||
bool encode_cudnn_plan_with_graph_api(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs);
|
||||
#endif
|
||||
void set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x,
|
||||
const std::vector<int64_t>& shape,
|
||||
const std::vector<int64_t>& strides);
|
||||
void set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x);
|
||||
void set_tensor_attrs_nchw(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x);
|
||||
|
||||
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
|
||||
template <typename... Args>
|
||||
bool encode_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
std::initializer_list<int64_t> uids,
|
||||
Args&... args) {
|
||||
assert(uids.size() == sizeof...(args));
|
||||
auto data_ptrs = get_data_ptrs(args...);
|
||||
return encode_cudnn_plan_with_capturing(
|
||||
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
|
||||
}
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
template <typename... Args>
|
||||
bool encode_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
std::initializer_list<int64_t> uids,
|
||||
Args&... args) {
|
||||
assert(uids.size() == sizeof...(args));
|
||||
auto data_ptrs = get_data_ptrs(args...);
|
||||
return encode_cudnn_plan_with_graph_api(
|
||||
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
|
||||
}
|
||||
#endif
|
||||
cudnnHandle_t handle_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -14,15 +14,6 @@ namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||
|
||||
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
if (err != CUDNN_STATUS_SUCCESS) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
bool use_cuda_graphs() {
|
||||
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||
return use_graphs;
|
||||
@@ -96,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
return;
|
||||
}
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
@@ -341,21 +332,30 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
|
||||
for (const auto& node : nodes) {
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type != cudaGraphNodeTypeKernel) {
|
||||
if (type == cudaGraphNodeTypeGraph) {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
if (num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
return is_graph_updatable(child, cluster_dim_x);
|
||||
} else if (type != cudaGraphNodeTypeKernel) {
|
||||
return false;
|
||||
} else {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only dim.x can be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
return false;
|
||||
}
|
||||
// Only one child node allowed when subgraph uses clusters
|
||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||
}
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only dim.x can be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
return false;
|
||||
}
|
||||
// Only one child node allowed when subgraph uses clusters
|
||||
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
|
||||
return false;
|
||||
}
|
||||
cluster_dim_x = cluster_dim.clusterDim.x;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -371,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
}
|
||||
cudaGraphNode_t node;
|
||||
int cluster_dim_x = 0;
|
||||
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
|
||||
is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(
|
||||
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
|
||||
|
||||
@@ -10,46 +10,8 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace fe = cudnn_frontend;
|
||||
|
||||
namespace {
|
||||
|
||||
#define CHECK_CUDNN_FE_ERROR(cmd) \
|
||||
do { \
|
||||
auto error = cmd; \
|
||||
if (!error.is_good()) { \
|
||||
throw std::runtime_error( \
|
||||
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
std::vector<int64_t> normalized_strides(const array& x) {
|
||||
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
|
||||
if (std::all_of(
|
||||
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
|
||||
strides.back() = 1;
|
||||
return strides;
|
||||
}
|
||||
if (!x.flags().row_contiguous || x.ndim() < 2) {
|
||||
return strides;
|
||||
}
|
||||
for (int i = x.ndim() - 2; i >= 0; --i) {
|
||||
if (x.shape(i) == 1) {
|
||||
strides[i] = x.shape(i + 1) * strides[i + 1];
|
||||
}
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
void set_tensor_attrs(
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
|
||||
int64_t uid,
|
||||
const array& x) {
|
||||
tensor->set_uid(uid)
|
||||
.set_dim({x.shape().begin(), x.shape().end()})
|
||||
.set_stride(normalized_strides(x));
|
||||
}
|
||||
|
||||
array prepare_sdpa_input(const array& x, Stream s) {
|
||||
// SDPA kernel's requirements on inputs:
|
||||
// 1. last dim's stride be 1;
|
||||
@@ -99,7 +61,7 @@ constexpr int QKV_NDIM = 4;
|
||||
|
||||
struct SDPACacheKey {
|
||||
int device_id;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
fe::DataType_t cudnn_dtype;
|
||||
std::array<int, QKV_NDIM> q_shape;
|
||||
std::array<int, QKV_NDIM> k_shape;
|
||||
std::array<int, QKV_NDIM> v_shape;
|
||||
@@ -143,13 +105,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
|
||||
}
|
||||
|
||||
auto& sdpa_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
|
||||
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
|
||||
return cache;
|
||||
}
|
||||
|
||||
auto& sdpa_backward_cache() {
|
||||
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
|
||||
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
|
||||
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
|
||||
return cache;
|
||||
}
|
||||
@@ -169,7 +131,7 @@ enum UIDS {
|
||||
D_O,
|
||||
};
|
||||
|
||||
fe::graph::Graph build_sdpa_graph(
|
||||
DnnGraph build_sdpa_graph(
|
||||
cudnnHandle_t handle,
|
||||
const array& q,
|
||||
const array& k,
|
||||
@@ -179,34 +141,15 @@ fe::graph::Graph build_sdpa_graph(
|
||||
bool output_logsumexp,
|
||||
const array& o,
|
||||
const array& stats) {
|
||||
auto dtype = fe::DataType_t::HALF;
|
||||
if (q.dtype() == bfloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
DnnGraph graph(handle, q.dtype());
|
||||
|
||||
fe::graph::Graph graph;
|
||||
graph.set_io_data_type(dtype)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
|
||||
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
|
||||
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
|
||||
set_tensor_attrs(q_, Q, q);
|
||||
set_tensor_attrs(k_, K, k);
|
||||
set_tensor_attrs(v_, V, v);
|
||||
|
||||
auto scale = graph.tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Scale")
|
||||
.set_uid(SCALE)
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
.set_data_type(fe::DataType_t::FLOAT));
|
||||
auto q_ = graph.tensor("Q", Q, q);
|
||||
auto k_ = graph.tensor("K", K, k);
|
||||
auto v_ = graph.tensor("V", V, v);
|
||||
|
||||
auto options = fe::graph::SDPA_attributes()
|
||||
.set_name("sdpa_cudnn")
|
||||
.set_attn_scale(scale)
|
||||
.set_attn_scale(graph.scalar("Scale", SCALE, float32))
|
||||
.set_generate_stats(output_logsumexp);
|
||||
if (do_causal) {
|
||||
if (q.shape(2) > k.shape(2)) {
|
||||
@@ -216,31 +159,23 @@ fe::graph::Graph build_sdpa_graph(
|
||||
}
|
||||
}
|
||||
if (mask_arr) {
|
||||
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
|
||||
set_tensor_attrs(bias_, BIAS, *mask_arr);
|
||||
options.set_bias(bias_);
|
||||
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
|
||||
}
|
||||
|
||||
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
|
||||
o_->set_output(true);
|
||||
set_tensor_attrs(o_, O, o);
|
||||
graph.tensor(o_, O, o)->set_output(true);
|
||||
if (output_logsumexp) {
|
||||
stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
|
||||
set_tensor_attrs(stats_, STATS, stats);
|
||||
graph.tensor(stats_, STATS, stats)->set_output(true);
|
||||
}
|
||||
|
||||
CHECK_CUDNN_FE_ERROR(graph.validate());
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
|
||||
CHECK_CUDNN_FE_ERROR(graph.prepare());
|
||||
graph.select_behavior_notes(
|
||||
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
|
||||
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
|
||||
|
||||
CHECK_CUDNN_FE_ERROR(graph.build());
|
||||
return graph;
|
||||
}
|
||||
|
||||
fe::graph::Graph build_sdpa_backward_graph(
|
||||
DnnGraph build_sdpa_backward_graph(
|
||||
cudnnHandle_t handle,
|
||||
const array& q,
|
||||
const array& k,
|
||||
@@ -253,42 +188,18 @@ fe::graph::Graph build_sdpa_backward_graph(
|
||||
array& d_q,
|
||||
array& d_k,
|
||||
array& d_v) {
|
||||
auto dtype = fe::DataType_t::HALF;
|
||||
if (q.dtype() == bfloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
DnnGraph graph(handle, q.dtype());
|
||||
|
||||
fe::graph::Graph graph;
|
||||
graph.set_io_data_type(dtype)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
|
||||
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
|
||||
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
|
||||
auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
|
||||
auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
|
||||
auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
|
||||
set_tensor_attrs(q_, Q, q);
|
||||
set_tensor_attrs(k_, K, k);
|
||||
set_tensor_attrs(v_, V, v);
|
||||
set_tensor_attrs(o_, O, o);
|
||||
set_tensor_attrs(d_o_, D_O, d_o);
|
||||
set_tensor_attrs(stats_, STATS, stats);
|
||||
stats_->set_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto scale = graph.tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Scale")
|
||||
.set_uid(SCALE)
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
.set_data_type(fe::DataType_t::FLOAT));
|
||||
auto q_ = graph.tensor("Q", Q, q);
|
||||
auto k_ = graph.tensor("K", K, k);
|
||||
auto v_ = graph.tensor("V", V, v);
|
||||
auto o_ = graph.tensor("O", O, o);
|
||||
auto d_o_ = graph.tensor("D_O", D_O, d_o);
|
||||
auto stats_ = graph.tensor("STATS", STATS, stats);
|
||||
|
||||
   auto options = fe::graph::SDPA_backward_attributes()
                      .set_name("sdpa_backward_cudnn")
-                     .set_attn_scale(scale);
+                     .set_attn_scale(graph.scalar("Scale", SCALE, float32));
if (do_causal) {
|
||||
if (q.shape(2) > k.shape(2)) {
|
||||
options.set_causal_mask(do_causal);
|
||||
@@ -297,56 +208,22 @@ fe::graph::Graph build_sdpa_backward_graph(
|
||||
}
|
||||
}
|
||||
if (mask_arr) {
|
||||
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
|
||||
set_tensor_attrs(bias_, BIAS, *mask_arr);
|
||||
options.set_bias(bias_);
|
||||
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
|
||||
}
|
||||
|
||||
auto [d_q_, d_k_, d_v_] =
|
||||
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
|
||||
d_q_->set_output(true);
|
||||
d_k_->set_output(true);
|
||||
d_v_->set_output(true);
|
||||
set_tensor_attrs(d_q_, D_Q, d_q);
|
||||
set_tensor_attrs(d_k_, D_K, d_k);
|
||||
set_tensor_attrs(d_v_, D_V, d_v);
|
||||
graph.tensor(d_q_, D_Q, d_q)->set_output(true);
|
||||
graph.tensor(d_k_, D_K, d_k)->set_output(true);
|
||||
graph.tensor(d_v_, D_V, d_v)->set_output(true);
|
||||
|
||||
CHECK_CUDNN_FE_ERROR(graph.validate());
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
|
||||
CHECK_CUDNN_FE_ERROR(graph.prepare());
|
||||
graph.select_behavior_notes(
|
||||
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
|
||||
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
|
||||
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
|
||||
|
||||
CHECK_CUDNN_FE_ERROR(graph.build());
|
||||
return graph;
|
||||
}
|
||||
|
||||
void execute_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnHandle_t handle,
|
||||
fe::graph::Graph& graph,
|
||||
std::unordered_map<int64_t, void*>& variant_pack) {
|
||||
int64_t workspace_size = 0;
|
||||
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
|
||||
void* workspace_ptr = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
array workspace(
|
||||
cu::malloc_async(workspace_size, encoder),
|
||||
{static_cast<int>(workspace_size)},
|
||||
uint8);
|
||||
encoder.add_temporary(workspace);
|
||||
workspace_ptr = gpu_ptr<void>(workspace);
|
||||
}
|
||||
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
CudaGraph cuda_graph(encoder.device());
|
||||
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
|
||||
handle, variant_pack, workspace_ptr, cuda_graph));
|
||||
encoder.add_graph_node(cuda_graph);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool supports_sdpa_cudnn(
|
||||
@@ -420,19 +297,19 @@ void sdpa_cudnn(
|
||||
auto& graph = it->second;
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack{
|
||||
{Q, const_cast<void*>(gpu_ptr<void>(q))},
|
||||
{K, const_cast<void*>(gpu_ptr<void>(k))},
|
||||
{V, const_cast<void*>(gpu_ptr<void>(v))},
|
||||
{Q, gpu_ptr<void>(q)},
|
||||
{K, gpu_ptr<void>(k)},
|
||||
{V, gpu_ptr<void>(v)},
|
||||
{SCALE, &scale},
|
||||
{O, gpu_ptr<void>(o)}};
|
||||
if (mask_arr) {
|
||||
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
|
||||
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
|
||||
}
|
||||
if (output_logsumexp) {
|
||||
variant_pack[STATS] = gpu_ptr<void>(stats);
|
||||
}
|
||||
|
||||
execute_graph(encoder, handle, graph, variant_pack);
|
||||
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
|
||||
}
|
||||
|
||||
void sdpa_backward_cudnn(
|
||||
@@ -480,21 +357,21 @@ void sdpa_backward_cudnn(
|
||||
auto& graph = it->second;
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack{
|
||||
{Q, const_cast<void*>(gpu_ptr<void>(q))},
|
||||
{K, const_cast<void*>(gpu_ptr<void>(k))},
|
||||
{V, const_cast<void*>(gpu_ptr<void>(v))},
|
||||
{Q, gpu_ptr<void>(q)},
|
||||
{K, gpu_ptr<void>(k)},
|
||||
{V, gpu_ptr<void>(v)},
|
||||
{SCALE, &scale},
|
||||
{O, const_cast<void*>(gpu_ptr<void>(o))},
|
||||
{STATS, const_cast<void*>(gpu_ptr<void>(stats))},
|
||||
{D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
|
||||
{O, gpu_ptr<void>(o)},
|
||||
{STATS, gpu_ptr<void>(stats)},
|
||||
{D_O, gpu_ptr<void>(d_o)},
|
||||
{D_Q, gpu_ptr<void>(d_q)},
|
||||
{D_K, gpu_ptr<void>(d_k)},
|
||||
{D_V, gpu_ptr<void>(d_v)}};
|
||||
if (mask_arr) {
|
||||
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
|
||||
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
|
||||
}
|
||||
|
||||
execute_graph(encoder, handle, graph, variant_pack);
|
||||
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
|
||||
}
|
||||
|
||||
// Defined in scaled_dot_product_attention.cu file.
|
||||
|
||||
@@ -32,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) {
   }
 }
 
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 const char* dtype_to_cuda_type(const Dtype& dtype) {
   switch (dtype) {
     case bool_:
@@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) {
       arr.offset());
 }
 
+// For const array, keep constness in pointer unless it is untyped.
 template <typename T>
-inline const T* gpu_ptr(const array& arr) {
+inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
+    const array& arr) {
   return gpu_ptr<T>(const_cast<array&>(arr));
 }
@@ -16,77 +16,7 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
|
||||
CCC="xcrun -sdk macosx metal -x metal"
|
||||
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
declare -a HDRS_LIST=($HDRS)
|
||||
declare -a HDRS_STACK=()
|
||||
declare -a HDRS_SORTED=()
|
||||
|
||||
length=${#HDRS_LIST[@]}
|
||||
|
||||
HDRS_LIST+=(".")
|
||||
|
||||
for ((i=0; i<${length}; i+=2));
|
||||
do
|
||||
|
||||
header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
|
||||
|
||||
str_this="${HDRS_LIST[$i]}"
|
||||
str_next="${HDRS_LIST[$i + 2]}"
|
||||
|
||||
depth_this=${#str_this}
|
||||
depth_next=${#str_next}
|
||||
|
||||
# If we have a dependency then we stack it
|
||||
if [ $depth_next -gt $depth_this ]; then
|
||||
HDRS_STACK=($header ${HDRS_STACK[@]})
|
||||
|
||||
# If we are done with this level
|
||||
else
|
||||
# We add the header to out list
|
||||
HDRS_SORTED+=($header)
|
||||
|
||||
# Pop the stacked up dependencies
|
||||
pop_len=$((depth_this - depth_next))
|
||||
for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
|
||||
do
|
||||
HDRS_SORTED+=($popped_header)
|
||||
done
|
||||
|
||||
HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
|
||||
fi
|
||||
|
||||
done
|
||||
|
||||
HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
|
||||
|
||||
CONTENT=$(
|
||||
echo "// Copyright © 2025 Apple Inc."
|
||||
echo ""
|
||||
echo "// Auto generated source for $INPUT_FILE"
|
||||
echo ""
|
||||
|
||||
for header in "${HDRS_SORTED[@]}"
|
||||
do
|
||||
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||
echo "// Contents from \"${header}\""
|
||||
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||
echo ""
|
||||
|
||||
echo "#line 1 \"${header}\""
|
||||
|
||||
grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
|
||||
|
||||
echo ""
|
||||
|
||||
done
|
||||
|
||||
echo "///////////////////////////////////////////////////////////////////////////////"
|
||||
)
|
||||
CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -382,6 +382,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(LogicalOr),
       SERIALIZE_PRIMITIVE(LogAddExp),
       SERIALIZE_PRIMITIVE(LogSumExp),
+      SERIALIZE_PRIMITIVE(MaskedScatter),
       SERIALIZE_PRIMITIVE(Matmul),
       SERIALIZE_PRIMITIVE(Maximum),
       SERIALIZE_PRIMITIVE(Minimum),
@@ -3466,10 +3466,8 @@ array masked_scatter(
   if (mask.dtype() != bool_) {
     throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
   }
-  if (mask.ndim() == 0) {
-    throw std::invalid_argument(
-        "[masked_scatter] Scalar masks are not supported.");
-  } else if (mask.ndim() > a.ndim()) {
+
+  if (mask.ndim() > a.ndim()) {
     throw std::invalid_argument(
         "[masked_scatter] The mask cannot have more dimensions than the target.");
   }
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
   }
 }
 
+std::vector<array> Reduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto in = primals[0];
+  auto s = stream();
+
+  auto grad_op = [&s, reduce_type = reduce_type_](
+                     const array& x, const array& tan, int axis) {
+    if (reduce_type == Reduce::Min) {
+      auto idx = argmin(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else if (reduce_type == Reduce::Max) {
+      auto idx = argmax(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else {
+      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
+      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
+      auto out = multiply(multiply(p1, p2, s), tan, s);
+      return sum(out, axis, true, s);
+    }
+  };
+
+  auto tan = tangents[0];
+  if (reduce_type_ == Reduce::Sum) {
+    return {sum(tan, axes_, true, s)};
+  } else {
+    if (axes_.size() > 1) {
+      std::vector<int> transpose_to;
+      {
+        // Find the transpose needed to move axes_ to the back.
+        int j = 0;
+        for (int i = 0; i < in.ndim(); i++) {
+          if (j < axes_.size() && axes_[j] == i) {
+            j++;
+          } else {
+            transpose_to.push_back(i);
+          }
+        }
+        for (auto ax : axes_) {
+          transpose_to.push_back(ax);
+        }
+      }
+
+      int start_ax = in.ndim() - axes_.size();
+      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
+      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
+
+      auto grad = squeeze(grad_op(in, tan, -1), -1, s);
+      return {expand_dims(grad, axes_, s)};
+    } else {
+      return {grad_op(in, tan, axes_[0])};
+    }
+  }
+}
+
 std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
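The Prod branch above relies on the identity that the directional derivative of a product is sum_i tan[i] * prod_{j != i} x[j], assembled here from two exclusive cumulative products. A quick numerical check of that identity in Python (values are illustrative, mirroring the style of the test added below):

    import mlx.core as mx

    x = mx.array([2.0, 3.0, 4.0])
    v = mx.array([1.0, 1.0, 1.0])

    # JVP of prod as computed by MLX.
    out, jout = mx.jvp(mx.prod, primals=(x,), tangents=(v,))

    # Same value from the exclusive-cumprod identity used by grad_op.
    p1 = mx.cumprod(x, reverse=False, inclusive=False)  # product of elements before i
    p2 = mx.cumprod(x, reverse=True, inclusive=False)   # product of elements after i
    manual = mx.sum(p1 * p2 * v)

    print(jout[0].item(), manual.item())  # both 26.0 = 3*4 + 2*4 + 2*3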
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
-
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+  DEFINE_GRADS();
 
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
@@ -1871,13 +1866,13 @@ class Scatter : public UnaryPrimitive {
|
||||
const char* name() const override {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return "ScatterSum";
|
||||
return "Scatter Sum";
|
||||
case Prod:
|
||||
return "ScatterProd";
|
||||
return "Scatter Prod";
|
||||
case Min:
|
||||
return "ScatterMin";
|
||||
return "Scatter Min";
|
||||
case Max:
|
||||
return "ScatterMax";
|
||||
return "Scatter Max";
|
||||
case None:
|
||||
return "Scatter";
|
||||
}
|
||||
@@ -1910,7 +1905,7 @@ class ScatterAxis : public UnaryPrimitive {
|
||||
const char* name() const override {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return "ScatterAxisSum";
|
||||
return "ScatterAxis Sum";
|
||||
case None:
|
||||
return "ScatterAxis";
|
||||
}
|
||||
|
||||
@@ -766,7 +766,7 @@ auto mlx_slice_update(
     const nb::object& obj,
     const ScalarOrArray& v) {
   // Can't route to slice update if not slice, tuple, or int
-  if (src.ndim() == 0 ||
+  if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
       (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
        !nb::isinstance<nb::int_>(obj))) {
     return std::make_pair(false, src);
@@ -888,7 +888,9 @@ auto mlx_slice_update(
 
 std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
   using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
-  if (nb::isinstance<mx::array>(obj)) {
+  if (nb::isinstance<nb::bool_>(obj)) {
+    return mx::array(nb::cast<bool>(obj), mx::bool_);
+  } else if (nb::isinstance<mx::array>(obj)) {
     auto mask = nb::cast<mx::array>(obj);
     if (mask.dtype() == mx::bool_) {
       return mask;
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
     if (mask.dtype() == nb::dtype<bool>()) {
       return nd_array_to_mlx(mask, mx::bool_);
     }
+  } else if (nb::isinstance<nb::list>(obj)) {
+    auto mask = array_from_list(nb::cast<nb::list>(obj), {});
+    if (mask.dtype() == mx::bool_) {
+      return mask;
+    }
   }
   return std::nullopt;
 }
@@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase):
         self.assertTrue(np.array_equal(a, anp))
 
     def test_setitem_with_boolean_mask(self):
         mask_np = np.zeros((10, 10), dtype=bool)
         mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
+
+        # Python list mask
+        a = mx.array([1.0, 2.0, 3.0])
+        mask = [True, False, True]
+        src = mx.array([5.0, 6.0])
+        expected = mx.array([5.0, 2.0, 6.0])
+        a[mask] = src
+        self.assertTrue(mx.array_equal(a, expected))
+
+        # mx.array scalar mask
+        a = mx.array([1.0, 2.0, 3.0])
+        mask = mx.array(True)
+        expected = mx.array([5.0, 5.0, 5.0])
+        a[mask] = 5.0
+        self.assertTrue(mx.array_equal(a, expected))
+
+        # scalar mask
+        a = mx.array([1.0, 2.0, 3.0])
+        mask = True
+        expected = mx.array([5.0, 5.0, 5.0])
+        a[mask] = 5.0
+        self.assertTrue(mx.array_equal(a, expected))
 
         mask_np = np.zeros((1, 10, 10), dtype=bool)
         with self.assertRaises(ValueError):
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
 
+    def test_reduce_jvp(self):
+        a = mx.arange(4)
+        b = mx.array([3, 2, 1, 0])
+
+        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 6)
+
+        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 18)
+
+        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 3)
+
+        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
@@ -596,6 +596,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
         for y in ys:
             self.assertEqual(imported(y)[0].item(), fun(y).item())
 
+    def test_export_import_scatter_sum(self):
+        def fun(x, y, z):
+            return x.at[y].add(z)
+
+        x = mx.array([1, 2, 3])
+        y = mx.array([0, 0, 1])
+        z = mx.array([1, 1, 1])
+        path = os.path.join(self.test_dir, "fn.mlxfn")
+        mx.export_function(path, fun, x, y, z)
+
+        imported = mx.import_function(path)
+        self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()