Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00

Compare commits: jit-nax ... 193cdcd81a (6 commits)

| SHA1 |
|---|
| 193cdcd81a |
| d8ceae7b77 |
| eff0e31f00 |
| 6c5785bc2f |
| 8879ee00eb |
| 6e762fe2e2 |
@@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are
 
 Boolean masks follow NumPy semantics:
 
-- The mask shape must match the shape of the axes it indexes exactly. No mask
-  broadcasting occurs.
+- The mask shape must match the shape of the axes it indexes exactly. The only
+  exception is a scalar boolean mask, which broadcasts to the full array.
 - Any axes not covered by the mask are taken in full.
 
 .. code-block:: shell
@@ -15,19 +15,16 @@ namespace mlx::core {
 
 namespace {
 
-// Alias for better readability.
-#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
-#define CONV_BACKWARD_INPUT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
-#define CONV_BACKWARD_WEIGHT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
+enum ConvBackendType {
+  CONV_FALLBACK,
+  CONV_FORWARD,
+  CONV_BACKWARD_INPUT,
+  CONV_BACKWARD_WEIGHT,
+};
 
-// Custom placeholder representing fallback kernel.
-#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
-
 struct ConvCacheKey {
   int device_id;
-  cudnnDataType_t cudnn_dtype;
+  fe::DataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
   std::array<int, MAX_NDIM> weight_shape;
   std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
 auto& conv_cache() {
   static LRUBytesKeyCache<
       ConvCacheKey,
-      std::pair<
-          cudnnBackendDescriptorType_t,
-          std::optional<cudnn_frontend::ExecutionPlan>>>
+      std::pair<ConvBackendType, std::optional<DnnGraph>>>
       cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
   return cache;
 }
 
-auto get_conv_op_settings(
-    cudnnBackendDescriptorType_t backend_type,
+auto get_conv_settings(
+    ConvBackendType backend_type,
     array& x,
     array& w,
     array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
     for (int i = 0; i < padding_lo.size(); ++i) {
       int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
      padding_lo[i] = wt_size - padding_lo[i] - 1;
-      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
-      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+      int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
+      int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
       padding_hi[i] = out_size - in_size + padding_hi[i];
     }
     return std::make_tuple(
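Editor's note on the padding arithmetic in the hunk above (a hedged restatement of the new code, not part of the change): with kernel dilation $d_i$, stride $s_i$, input dilation $\delta_i$, and spatial sizes $w_i$, $x_i$, $y_i$ of the w/x/y tensors, the loop rewrites the padding as

$$\tilde{w}_i = 1 + d_i\,(w_i - 1), \qquad p^{\mathrm{lo}}_i \leftarrow \tilde{w}_i - p^{\mathrm{lo}}_i - 1, \qquad p^{\mathrm{hi}}_i \leftarrow \bigl(1 + \delta_i\,(x_i - 1)\bigr) - \bigl(1 + s_i\,(y_i - 1)\bigr) + p^{\mathrm{hi}}_i,$$

i.e. the low padding is flipped against the dilated kernel extent and the high padding is rebalanced so the strided extent of y and the dilated extent of x line up; the fix in this hunk swaps which of x and y each term reads from.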
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
 }
 }
 
-std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
+std::optional<DnnGraph> build_conv_graph(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     Dtype dtype,
     array& x,
     array& w,
     array& y,
-    const SmallVector<int64_t>& stride,
-    const SmallVector<int64_t>& padding_lo,
-    const SmallVector<int64_t>& padding_hi,
-    const SmallVector<int64_t>& dilation) {
-  try {
-    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
-        ? CUDNN_DATA_FLOAT
-        : dtype_to_cudnn_type(dtype);
-    auto conv_desc = cudnn_frontend::ConvDescBuilder()
-                         .setDataType(compute_dtype)
-                         .setMathMode(CUDNN_CROSS_CORRELATION)
-                         .setNDims(stride.size())
-                         .setStrides(stride.size(), stride.data())
-                         .setPrePadding(padding_lo.size(), padding_lo.data())
-                         .setPostPadding(padding_hi.size(), padding_hi.data())
-                         .setDilation(dilation.size(), dilation.data())
-                         .build();
+    const std::vector<int64_t>& stride,
+    const std::vector<int64_t>& padding_lo,
+    const std::vector<int64_t>& padding_hi,
+    const std::vector<int64_t>& dilation) {
+  auto compute_dtype =
+      (dtype == float16 || dtype == bfloat16) ? float32 : dtype;
+  DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
+  auto x_ = graph.tensor_nchw("X", 'x', x);
+  auto w_ = graph.tensor_nchw("W", 'w', w);
 
-    auto op = cudnn_frontend::OperationBuilder(backend_type)
-                  .setxDesc(build_cudnn_tensor_nchw('x', x))
-                  .setwDesc(build_cudnn_tensor_nchw('w', w))
-                  .setyDesc(build_cudnn_tensor_nchw('y', y))
-                  .setcDesc(conv_desc)
-                  .build();
+  auto set_options = [&](auto& options) {
+    options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
+        .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
+        .set_stride(stride)
+        .set_pre_padding(padding_lo)
+        .set_post_padding(padding_hi)
+        .set_dilation(dilation);
+  };
 
-    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
-    return cudnn_frontend::OperationGraphBuilder()
-        .setHandle(encoder.device().cudnn_handle())
-        .setOperationGraph(ops.size(), ops.data())
-        .build();
-  } catch (cudnn_frontend::cudnnException& error) {
-    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
-      throw;
-    }
+  std::shared_ptr<fe::graph::Tensor_attributes> y_;
+  if (backend_type == CONV_FORWARD) {
+    auto options = fe::graph::Conv_fprop_attributes();
+    set_options(options);
+    y_ = graph.conv_fprop(x_, w_, options);
+  } else if (backend_type == CONV_BACKWARD_INPUT) {
+    auto options = fe::graph::Conv_dgrad_attributes();
+    set_options(options);
+    y_ = graph.conv_dgrad(x_, w_, options);
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    auto options = fe::graph::Conv_wgrad_attributes();
+    set_options(options);
+    y_ = graph.conv_wgrad(w_, x_, options);
+  }
+  graph.tensor_nchw(y_, 'y', y)->set_output(true);
+
+  if (graph.prepare().is_bad()) {
     return std::nullopt;
   }
+  graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
+  if (dtype == float32 && !env::enable_tf32()) {
+    graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
+  }
+  CHECK_CUDNN_FE_ERROR(graph.build());
+  return graph;
 }
 
 // Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
 // eval_gpu, with cost of possible redundant copies.
 std::tuple<array, array, array> prepare_args(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     array in,
     array wt,
     array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
   return {std::move(in), std::move(wt), std::move(out)};
 }
 
-// Get the x/w/y args from the in/wt/out args depending on backend type.
-inline std::tuple<array&, array&, array&> dispatch_args(
-    cudnnBackendDescriptorType_t backend_type,
-    array& in,
-    array& wt,
-    array& out) {
-  switch (backend_type) {
-    case CONV_BACKWARD_INPUT:
-      return {out, wt, in};
-    case CONV_BACKWARD_WEIGHT:
-      return {in, out, wt};
-    default:
-      return {in, wt, out};
-  }
-}
-
 // Register inputs and outputs before actually running conv op. Can only be
 // called once per eval_gpu.
 void register_args(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     array& in,
     array& wt,
     array& intermediate_out,
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       get_alignment(wt),
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
-    auto& [backend_type, plan] = it->second;
-    if (plan) {
-      // Run cached plan.
+    auto& [backend_type, graph] = it->second;
+    if (graph) {
+      // Run cached graph.
       std::tie(in, wt, out) =
           prepare_args(encoder, backend_type, in, wt, out, groups_, s);
       register_args(encoder, backend_type, in, wt, out, out_);
-      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-      if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-        throw std::runtime_error("[conv] Cached plan failed to execute.");
-      }
+      CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
+          encoder,
+          {
+              {'x', gpu_ptr<void>(in)},
+              {'w', gpu_ptr<void>(wt)},
+              {'y', gpu_ptr<void>(out)},
+          }));
     } else {
       // Run fallback kernel.
       gemm_conv(
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
 
   // There is no reliable way to deduce the proper cuDNN backend for the
   // convolution, so we make a best guess and then try.
-  SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
+  SmallVector<ConvBackendType, 2> try_backends;
   if (flip_) {
     // When weight is flipped, we assume it is backward input convolution.
     try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   }
 
   // Try to build op graph.
-  cudnnBackendDescriptorType_t backend_type;
-  std::optional<cudnn_frontend::OperationGraph> op_graph;
+  ConvBackendType backend_type;
+  std::optional<DnnGraph> graph;
   for (auto try_backend : try_backends) {
-    auto [in_copy, wt_copy, out_copy] =
+    auto [x, w, y] =
         prepare_args(encoder, try_backend, in, wt, out, groups_, s);
-    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
-    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
+    auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
         try_backend,
         x,
         w,
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_hi_,
         kernel_dilation_,
         input_dilation_);
-    op_graph = build_conv_op_graph(
+    graph = build_conv_graph(
         encoder,
         try_backend,
         dtype,
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_lo,
         padding_hi,
         dilation);
-    if (op_graph) {
+    if (graph) {
       backend_type = try_backend;
-      in = std::move(in_copy);
-      wt = std::move(wt_copy);
-      out = std::move(out_copy);
+      in = std::move(x);
+      wt = std::move(w);
+      out = std::move(y);
       break;
     }
   }
 
-  if (op_graph) {
-    // Find a plan for the graph and execute it.
-    auto plan = find_cudnn_plan_from_op_graph(
-        encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
-    if (plan) {
-      // Setup inputs and outputs.
-      register_args(encoder, backend_type, in, wt, out, out_);
-
-      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-      if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-        conv_cache().emplace(
-            cache_key, std::make_pair(backend_type, std::move(*plan)));
-        return;
-      }
-    }
+  if (graph) {
+    register_args(encoder, backend_type, in, wt, out, out_);
+    CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
+        encoder,
+        {
+            {'x', gpu_ptr<void>(in)},
+            {'w', gpu_ptr<void>(wt)},
+            {'y', gpu_ptr<void>(out)},
+        }));
+    conv_cache().emplace(
+        cache_key, std::make_pair(backend_type, std::move(*graph)));
+    return;
   }
 
   // Use fallback kernel for settings not supported by cuDNN.
@@ -5,6 +5,7 @@
 #include <cublasLt.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <cudnn.h>
 
 namespace mlx::core {
 
@@ -12,10 +13,12 @@ namespace mlx::core {
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
 void check_cuda_error(const char* name, CUresult err);
+void check_cudnn_error(const char* name, cudnnStatus_t err);
 
 // The macro version that prints the command that failed.
 #define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
 #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
 
 // Base class for RAII managed CUDA resources.
 template <typename Handle, cudaError_t (*Destroy)(Handle)>
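For orientation (not part of the diff): a minimal sketch of how the newly added check macro is used. cudnnCreate and cudnnDestroy are standard cuDNN entry points; the wrapper function name and the assumption that the macro lives in the in-tree utils header are illustrative only.

#include <cudnn.h>

#include "mlx/backend/cuda/utils.h"  // assumed location of CHECK_CUDNN_ERROR

// On failure the macro throws std::runtime_error with a message of the form
// "cudnnCreate(&handle) failed: <cudnnGetErrorString(err)>".
inline void cudnn_handle_smoke_test() {
  cudnnHandle_t handle;
  CHECK_CUDNN_ERROR(cudnnCreate(&handle));
  CHECK_CUDNN_ERROR(cudnnDestroy(handle));
}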
@@ -7,32 +7,26 @@ namespace mlx::core {
 
 namespace {
 
-// Create a cudnn tensor descriptor.
-template <typename Vec>
-inline cudnn_frontend::Tensor build_cudnn_tensor(
-    int64_t id,
-    const array& x,
-    const Vec& shape,
-    const Vec& strides) {
-  return cudnn_frontend::TensorBuilder()
-      .setDim(shape.size(), shape.data())
-      .setStrides(strides.size(), strides.data())
-      .setId(id)
-      .setAlignment(get_alignment(x))
-      .setDataType(dtype_to_cudnn_type(x.dtype()))
-      .build();
-}
+#define RETURN_IF_ERROR(cmd)            \
+  if (auto ret = cmd; ret.is_bad()) {   \
+    return ret;                         \
+  }
 
 // In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
 // whether a tensor is contiguous is determined with:
 // shape[dim] == shape[dim + 1] * strides[dim + 1]
 // So a contiguous array with singleton dims in MLX may be mistakenly treated
 // as strided in cuDNN, and we work around it by normalizing the strides.
-Strides normalized_strides(const array& x) {
-  if (!x.flags().row_contiguous || x.ndim() < 2) {
-    return x.strides();
+std::vector<int64_t> normalized_strides(const array& x) {
+  std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
+  if (std::all_of(
+          strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
+    strides.back() = 1;
+    return strides;
+  }
+  if (!x.flags().row_contiguous || x.ndim() < 2) {
+    return strides;
   }
-  Strides strides = x.strides();
   for (int i = x.ndim() - 2; i >= 0; --i) {
     if (x.shape(i) == 1) {
       strides[i] = x.shape(i + 1) * strides[i + 1];
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
 }
 
 // Return the shape and strides after transposing from NHWC to NCHW.
-auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
+inline auto nhwc_to_nchw(const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
+  auto strides = normalized_strides(x);
   assert(shape.size() >= 3);
   shape.insert(shape.begin() + 1, shape.back());
   shape.erase(shape.end() - 1);
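A side note on the stride normalization above (an illustrative, self-contained sketch; the shape and numbers are made up, only the rewrite rule comes from the diff): cuDNN treats a dimension as contiguous when strides[dim] == shape[dim + 1] * strides[dim + 1], so an MLX row-contiguous array with a singleton dimension, whose stride MLX is free to leave arbitrary, can look strided to cuDNN until the singleton stride is rewritten.

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // Shape (1, 4): MLX may store strides (1, 1) since dim 0 has extent 1.
  std::array<int64_t, 2> shape = {1, 4};
  std::array<int64_t, 2> strides = {1, 1};
  // Same backward walk as normalized_strides(): give every singleton dim the
  // stride cuDNN's contiguity rule expects.
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  assert(strides[0] == 4);  // now strides[0] == shape[1] * strides[1]
  return 0;
}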
@@ -51,226 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
   return std::make_tuple(std::move(shape), std::move(strides));
 }
 
-inline auto nhwc_to_nchw(const array& x) {
-  return nhwc_to_nchw(
-      convert_vector<int64_t>(x.shape()), normalized_strides(x));
-}
-
-// Return available engines for a |op_graph|.
-cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph,
-    bool use_fallback = true) {
-  SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
-  sources.push_back([](auto& op_graph) {
-    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                          .setOperationGraph(op_graph)
-                          .setHeurMode(CUDNN_HEUR_MODE_A)
-                          .build();
-    return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
-  });
-  if (use_fallback) {
-    sources.push_back([&backend_type](auto& op_graph) {
-      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                          .setOperationGraph(op_graph)
-                          .setOperation(backend_type)
-                          .build();
-      return fallback.getFallbackList();
-    });
-  }
-
-  auto configs =
-      cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
-          .generate_engine_config(op_graph);
-
-  cudnn_frontend::EngineConfigList filtered_configs;
-  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
-    if (cudnn_frontend::hasNumericalNote<
-            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
-      return true;
-    }
-    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
-        dtype == float32 && !env::enable_tf32()) {
-      return true;
-    }
-    return false;
-  });
-  return filtered_configs;
-}
-
-// Take |engine_configs| and |op_graph| and find a working execution plans
-// from them.
-std::optional<cudnn_frontend::ExecutionPlan>
-find_cudnn_plan_from_engine_configs(
-    cudnnHandle_t handle,
-    const cudnn_frontend::EngineConfigList& engine_configs,
-    const cudnn_frontend::OperationGraph& op_graph) {
-  auto op_graph_tag = op_graph.getTag();
-  for (const auto& config : engine_configs) {
-    try {
-      return cudnn_frontend::ExecutionPlanBuilder()
-          .setHandle(handle)
-          .setEngineConfig(config, op_graph_tag)
-          .build();
-    } catch (cudnn_frontend::cudnnException& error) {
-      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
-        throw;
-      }
-    }
-  }
-  return std::nullopt;
-}
-
-// Prepare workspace and args to execute plan.
-template <typename F>
-bool prepare_cudnn_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs,
-    F&& execute) {
-  int workspace_size = plan.getWorkspaceSize();
-  void* workspace_ptr = nullptr;
-  if (workspace_size > 0) {
-    array workspace(
-        cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8);
-    encoder.add_temporary(workspace);
-    workspace_ptr = gpu_ptr<void>(workspace);
-  }
-
-  auto args = cudnn_frontend::VariantPackBuilder()
-                  .setWorkspacePointer(workspace_ptr)
-                  .setDataPointers(num_args, data_ptrs)
-                  .setUids(num_args, uids)
-                  .build();
-
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
-
-  if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
-    return false;
-  }
-
-  return true;
-}
-
 } // namespace
 
-cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
-  auto shape = convert_vector<int64_t>(x.shape());
-  return build_cudnn_tensor(id, x, shape, normalized_strides(x));
+fe::error_t DnnGraph::prepare() {
+  RETURN_IF_ERROR(validate());
+  try {
+    RETURN_IF_ERROR(build_operation_graph(handle_));
+  } catch (cudnn_frontend::cudnnException& error) {
+    // cuDNN bug: they did not catch all exceptions in the API.
+    return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
+  }
+  RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
+  return {};
 }
 
-cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
+fe::error_t DnnGraph::build() {
+  RETURN_IF_ERROR(check_support(handle_));
+  RETURN_IF_ERROR(build_plans(handle_));
+  return {};
+}
+
+fe::error_t DnnGraph::encode_graph(
+    cu::CommandEncoder& encoder,
+    std::unordered_map<int64_t, void*> variant_pack) {
+  cudnnSetStream(handle_, encoder.stream());
+  CudaGraph cuda_graph(encoder.device());
+  RETURN_IF_ERROR(populate_cuda_graph(
+      handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
+  encoder.add_graph_node(cuda_graph);
+  return {};
+}
+
+fe::error_t DnnGraph::encode_capturing(
+    cu::CommandEncoder& encoder,
+    std::unordered_map<int64_t, void*> variant_pack) {
+  auto* workspace_ptr = prepare_workspace(encoder);
+  auto capture = encoder.capture_context();
+  cudnnSetStream(handle_, encoder.stream());
+  auto ret = execute(handle_, variant_pack, workspace_ptr);
+  if (ret.is_bad()) {
+    capture.discard = true;
+  }
+  return ret;
+}
+
+void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
+  int64_t workspace_size = 0;
+  CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
+  if (workspace_size > 0) {
+    array workspace(
+        cu::malloc_async(workspace_size, encoder),
+        {static_cast<int>(workspace_size)},
+        uint8);
+    encoder.add_temporary(workspace);
+    return gpu_ptr<void>(workspace);
+  }
+  return nullptr;
+}
+
+void DnnGraph::set_tensor_attrs(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x,
+    const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& strides) {
+  tensor->set_uid(uid)
+      .set_alignment(get_alignment(x))
+      .set_data_type(dtype_to_cudnn_type(x.dtype()))
+      .set_dim(shape)
+      .set_stride(strides);
+}
+
+void DnnGraph::set_tensor_attrs(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x) {
+  set_tensor_attrs(
+      tensor,
+      uid,
+      x,
+      convert_vector<int64_t>(x.shape()),
+      normalized_strides(x));
+}
+
+void DnnGraph::set_tensor_attrs_nchw(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x) {
   auto [shape, strides] = nhwc_to_nchw(x);
-  return build_cudnn_tensor(id, x, shape, strides);
+  set_tensor_attrs(tensor, uid, x, shape, strides);
 }
 
-cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
-  if (x.ndim() == 0) {
-    SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
-    return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
-  }
-  if (x.ndim() == 1) {
-    int64_t s = x.shape(0);
-    SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
-    SmallVector<int64_t, 4> strides = {s, 1, s, s};
-    return build_cudnn_tensor(id, x, shape, strides);
-  }
-  if (x.ndim() == 2) {
-    int64_t s =
-        x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
-    SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
-    SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
-    return build_cudnn_tensor(id, x, shape, strides);
-  }
-  if (x.ndim() == 3 || x.ndim() == 4) {
-    return build_cudnn_tensor_nchw(id, x);
-  }
-  throw std::runtime_error(
-      fmt::format("Unsupported array with {} dims.", x.ndim()));
-}
-
-cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
-  SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
-  return cudnn_frontend::TensorBuilder()
-      .setDim(scalar_dims.size(), scalar_dims.data())
-      .setStrides(scalar_dims.size(), scalar_dims.data())
-      .setId(id)
-      .setAlignment(16)
-      .setDataType(dtype_to_cudnn_type(dtype))
-      .setByValue(true)
-      .build();
-}
-
-std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
-    cudnnHandle_t handle,
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph) {
-  auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
-  if (engine_configs.empty()) {
-    return std::nullopt;
-  }
-  return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
-}
-
-bool encode_cudnn_plan_with_capturing(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs) {
-  return prepare_cudnn_plan(
-      encoder,
-      plan,
-      num_args,
-      uids,
-      data_ptrs,
-      [&](auto handle, auto plan, auto args) {
-        auto capture = encoder.capture_context();
-        if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
-          // Discard the captured graph when failed.
-          capture.discard = true;
-          return false;
-        }
-        return true;
-      });
-}
-
-#if CUDNN_VERSION >= 90500
-bool encode_cudnn_plan_with_graph_api(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    CudaGraph& graph,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs) {
-  return prepare_cudnn_plan(
-      encoder,
-      plan,
-      num_args,
-      uids,
-      data_ptrs,
-      [&](auto handle, auto plan, auto args) {
-        if (!graph) {
-          graph = CudaGraph(encoder.device());
-          if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
-              CUDNN_STATUS_SUCCESS) {
-            return false;
-          }
-        } else {
-          if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
-              CUDNN_STATUS_SUCCESS) {
-            return false;
-          }
-        }
-        encoder.add_graph_node(graph);
-        return true;
-      });
-}
-#endif
-
 } // namespace mlx::core
@@ -2,25 +2,30 @@
 
 #pragma once
 
-#include "mlx/array.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/dtype_utils.h"
 
 #include <cudnn_frontend.h>
-#include <cudnn_frontend_find_plan.h>
 #include <fmt/format.h>
 
-#include <algorithm>
-#include <array>
-
 namespace mlx::core {
 
 namespace cu {
 class CommandEncoder;
 }
 
+namespace fe = cudnn_frontend;
+
+#define CHECK_CUDNN_FE_ERROR(cmd)                                     \
+  do {                                                                \
+    auto error = cmd;                                                 \
+    if (!error.is_good()) {                                           \
+      throw std::runtime_error(                                       \
+          fmt::format("{} failed: {}.", #cmd, error.get_message()));  \
+    }                                                                 \
+  } while (0)
+
 // Return pointer alignment of |x|'s data.
 inline uint8_t get_alignment(const array& x) {
   uint8_t alignment = 1;
@@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {
 
 // Convert the type of elements in |vec| to |T|.
 template <typename T, typename Vec>
-inline SmallVector<T> convert_vector(const Vec& vec) {
-  return SmallVector<T>(vec.begin(), vec.end());
+inline std::vector<T> convert_vector(const Vec& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
+}
+
+// Map dtype to cudnn data type.
+inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return fe::DataType_t::INT8;
+    case int32:
+      return fe::DataType_t::INT32;
+    case uint8:
+      return fe::DataType_t::UINT8;
+    case float16:
+      return fe::DataType_t::HALF;
+    case bfloat16:
+      return fe::DataType_t::BFLOAT16;
+    case float32:
+      return fe::DataType_t::FLOAT;
+    case float64:
+      return fe::DataType_t::DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
+  }
 }
 
 // Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
@@ -55,111 +83,89 @@ inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
   return result;
 }
 
-// Helpers used by get_data_ptrs to get pointers.
-inline void* get_data_ptr(const array& arr) {
-  return const_cast<void*>(gpu_ptr<void>(arr));
-}
-
-template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
-inline void* get_data_ptr(T& scalar) {
-  return &scalar;
-}
-
-// Return an array filled with data pointers of args.
-template <typename... Args>
-inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
-  return {get_data_ptr(args)...};
-}
-
-// Map dtype to cudnn data type.
-inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
-  switch (dtype) {
-    case int8:
-      return CUDNN_DATA_INT8;
-    case int32:
-      return CUDNN_DATA_INT32;
-    case uint8:
-      return CUDNN_DATA_UINT8;
-    case float16:
-      return CUDNN_DATA_HALF;
-    case bfloat16:
-      return CUDNN_DATA_BFLOAT16;
-    case float32:
-      return CUDNN_DATA_FLOAT;
-    case float64:
-      return CUDNN_DATA_DOUBLE;
-    default:
-      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
-  }
-}
-
-// Create a tensor descriptor from |x|.
-cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
-
-// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
-cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
-
-// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
-// from NHWC to NCHW.
-cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
-
-// Create a 4D scalar tensor descriptor, which is passed by value.
-cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
-
-// Find a working plan for |op_graph|.
-std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
-    cudnnHandle_t handle,
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph);
-
-// Encode the plan to command buffer by capturing.
-bool encode_cudnn_plan_with_capturing(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs);
-
-#if CUDNN_VERSION >= 90500
-// Encode the plan to command buffer by using native graph api of cudnn. If the
-// |graph| is empty it will be populated, otherwise it will be updated.
-bool encode_cudnn_plan_with_graph_api(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    CudaGraph& graph,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs);
-#endif
-
-// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
-template <typename... Args>
-bool encode_cudnn_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    std::initializer_list<int64_t> uids,
-    Args&... args) {
-  assert(uids.size() == sizeof...(args));
-  auto data_ptrs = get_data_ptrs(args...);
-  return encode_cudnn_plan_with_capturing(
-      encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
-}
-
-#if CUDNN_VERSION >= 90500
-template <typename... Args>
-bool encode_cudnn_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    CudaGraph& graph,
-    std::initializer_list<int64_t> uids,
-    Args&... args) {
-  assert(uids.size() == sizeof...(args));
-  auto data_ptrs = get_data_ptrs(args...);
-  return encode_cudnn_plan_with_graph_api(
-      encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
-}
-#endif
+// Extends cuDNN graph with helpers.
+class DnnGraph : public fe::graph::Graph {
+ public:
+  DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
+      : handle_(handle) {
+    set_io_data_type(dtype_to_cudnn_type(io_dtype));
+    set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
+    set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
+  }
+
+  // Create a cuDNN tensor description from MLX array |x|.
+  auto& tensor(
+      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+      int64_t uid,
+      const array& x) {
+    set_tensor_attrs(attrs, uid, x);
+    return attrs;
+  }
+  auto tensor(const char* name, int64_t uid, const array& x) {
+    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+    tensor(attrs, uid, x);
+    return attrs;
+  }
+
+  // Create a cuDNN tensor description from MLX array |x|, and transpose it from
+  // NHWC layout to NCHW.
+  auto& tensor_nchw(
+      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+      int64_t uid,
+      const array& x) {
+    set_tensor_attrs_nchw(attrs, uid, x);
+    return attrs;
+  }
+  auto tensor_nchw(const char* name, int64_t uid, const array& x) {
+    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+    tensor_nchw(attrs, uid, x);
+    return attrs;
+  }
+
+  // Create a cuDNN tensor for scalar.
+  auto scalar(const char* name, int64_t uid, Dtype dtype) {
+    return Graph::tensor(fe::graph::Tensor_attributes()
+                             .set_name(name)
+                             .set_uid(uid)
+                             .set_dim({1, 1, 1, 1})
+                             .set_stride({1, 1, 1, 1})
+                             .set_is_pass_by_value(true)
+                             .set_data_type(dtype_to_cudnn_type(dtype)));
+  }
+
+  // Call this before setting notes.
+  fe::error_t prepare();
+  // Call this after setting notes.
+  fe::error_t build();
+
+  // Add cuDNN graph to CUDA graph, using native CUDA graph API.
+  fe::error_t encode_graph(
+      cu::CommandEncoder& encoder,
+      std::unordered_map<int64_t, void*> variant_pack);
+  // Add cuDNN graph to CUDA graph, using stream capture.
+  fe::error_t encode_capturing(
+      cu::CommandEncoder& encoder,
+      std::unordered_map<int64_t, void*> variant_pack);
+
+ private:
+  void* prepare_workspace(cu::CommandEncoder& encoder);
+
+  void set_tensor_attrs(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x,
+      const std::vector<int64_t>& shape,
+      const std::vector<int64_t>& strides);
+  void set_tensor_attrs(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x);
+  void set_tensor_attrs_nchw(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x);
+
+  cudnnHandle_t handle_;
+};
 
 } // namespace mlx::core
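A hedged usage sketch of the DnnGraph helper declared above, mirroring the call order the conv path in this change uses (describe tensors and the op, prepare, select notes, build, then encode). This is not a standalone program: it assumes the in-tree MLX CUDA backend types (cu::CommandEncoder, array, Dtype, gpu_ptr) and the headers from this change; stride/padding/dilation handling on the attributes is elided.

// Sketch only: conv_fprop comes from the cudnn_frontend Graph base class;
// the function name and argument list here are illustrative, not MLX API.
void dnn_graph_usage_sketch(
    cu::CommandEncoder& encoder, Dtype dtype, array& x, array& w, array& y) {
  DnnGraph graph(encoder.device().cudnn_handle(), dtype, /* compute */ float32);
  auto x_ = graph.tensor_nchw("X", 'x', x);
  auto w_ = graph.tensor_nchw("W", 'w', w);
  auto options = fe::graph::Conv_fprop_attributes();
  // ...stride / padding / dilation would be set on |options| as in build_conv_graph...
  auto y_ = graph.conv_fprop(x_, w_, options);
  graph.tensor_nchw(y_, 'y', y)->set_output(true);

  if (graph.prepare().is_bad()) {  // validate + build op graph + pick plans
    return;                        // caller falls back to another kernel
  }
  graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
  CHECK_CUDNN_FE_ERROR(graph.build());  // check support + build plans
  CHECK_CUDNN_FE_ERROR(graph.encode_capturing(
      encoder,
      {{'x', gpu_ptr<void>(x)}, {'w', gpu_ptr<void>(w)}, {'y', gpu_ptr<void>(y)}}));
}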
@@ -14,15 +14,6 @@ namespace mlx::core::cu {
 
 namespace {
 
-#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
-
-void check_cudnn_error(const char* name, cudnnStatus_t err) {
-  if (err != CUDNN_STATUS_SUCCESS) {
-    throw std::runtime_error(
-        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
-  }
-}
-
 bool use_cuda_graphs() {
   static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
   return use_graphs;
@@ -96,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
     return;
   }
   CHECK_CUDA_ERROR(
-      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
+      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
@@ -341,21 +332,30 @@ bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
   for (const auto& node : nodes) {
     cudaGraphNodeType type;
     CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
-    if (type != cudaGraphNodeTypeKernel) {
+    if (type == cudaGraphNodeTypeGraph) {
+      // Try to be updatable for a structure like graph -> graph -> kernel
+      if (num_nodes > 1) {
+        return false;
+      }
+      cudaGraph_t child;
+      CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
+      return is_graph_updatable(child, cluster_dim_x);
+    } else if (type != cudaGraphNodeTypeKernel) {
       return false;
+    } else {
+      cudaLaunchAttributeValue cluster_dim;
+      CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
+          node, cudaLaunchAttributeClusterDimension, &cluster_dim));
+      // Only dim.x can be greater than 1
+      if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
+        return false;
+      }
+      // Only one child node allowed when subgraph uses clusters
+      if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
+        return false;
+      }
+      cluster_dim_x = cluster_dim.clusterDim.x;
     }
-    cudaLaunchAttributeValue cluster_dim;
-    CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
-        node, cudaLaunchAttributeClusterDimension, &cluster_dim));
-    // Only dim.x can be greater than 1
-    if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
-      return false;
-    }
-    // Only one child node allowed when subgraph uses clusters
-    if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
-      return false;
-    }
-    cluster_dim_x = cluster_dim.clusterDim.x;
   }
   return true;
 }
@@ -371,7 +371,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   }
   cudaGraphNode_t node;
   int cluster_dim_x = 0;
-  is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
+  is_graph_updatable_ &= is_graph_updatable(child, cluster_dim_x);
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(
       GraphNode{node, "G" + std::to_string(cluster_dim_x)});
@@ -10,46 +10,8 @@
 
 namespace mlx::core {
 
-namespace fe = cudnn_frontend;
-
 namespace {
 
-#define CHECK_CUDNN_FE_ERROR(cmd)                                     \
-  do {                                                                \
-    auto error = cmd;                                                 \
-    if (!error.is_good()) {                                           \
-      throw std::runtime_error(                                       \
-          fmt::format("{} failed: {}.", #cmd, error.get_message()));  \
-    }                                                                 \
-  } while (0)
-
-std::vector<int64_t> normalized_strides(const array& x) {
-  std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
-  if (std::all_of(
-          strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
-    strides.back() = 1;
-    return strides;
-  }
-  if (!x.flags().row_contiguous || x.ndim() < 2) {
-    return strides;
-  }
-  for (int i = x.ndim() - 2; i >= 0; --i) {
-    if (x.shape(i) == 1) {
-      strides[i] = x.shape(i + 1) * strides[i + 1];
-    }
-  }
-  return strides;
-}
-
-void set_tensor_attrs(
-    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
-    int64_t uid,
-    const array& x) {
-  tensor->set_uid(uid)
-      .set_dim({x.shape().begin(), x.shape().end()})
-      .set_stride(normalized_strides(x));
-}
-
 array prepare_sdpa_input(const array& x, Stream s) {
   // SDPA kernel's requirements on inputs:
   // 1. last dim's stride be 1;
@@ -99,7 +61,7 @@ constexpr int QKV_NDIM = 4;
 
 struct SDPACacheKey {
   int device_id;
-  cudnnDataType_t cudnn_dtype;
+  fe::DataType_t cudnn_dtype;
   std::array<int, QKV_NDIM> q_shape;
   std::array<int, QKV_NDIM> k_shape;
   std::array<int, QKV_NDIM> v_shape;
@@ -143,13 +105,13 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
 }
 
 auto& sdpa_cache() {
-  static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
+  static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
       "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
   return cache;
 }
 
 auto& sdpa_backward_cache() {
-  static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
+  static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
   return cache;
 }
@@ -169,7 +131,7 @@ enum UIDS {
   D_O,
 };
 
-fe::graph::Graph build_sdpa_graph(
+DnnGraph build_sdpa_graph(
     cudnnHandle_t handle,
     const array& q,
     const array& k,
@@ -179,34 +141,15 @@ fe::graph::Graph build_sdpa_graph(
     bool output_logsumexp,
     const array& o,
     const array& stats) {
-  auto dtype = fe::DataType_t::HALF;
-  if (q.dtype() == bfloat16) {
-    dtype = fe::DataType_t::BFLOAT16;
-  }
-
-  fe::graph::Graph graph;
-  graph.set_io_data_type(dtype)
-      .set_intermediate_data_type(fe::DataType_t::FLOAT)
-      .set_compute_data_type(fe::DataType_t::FLOAT);
+  DnnGraph graph(handle, q.dtype());
+  auto q_ = graph.tensor("Q", Q, q);
+  auto k_ = graph.tensor("K", K, k);
+  auto v_ = graph.tensor("V", V, v);
 
-  auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
-  auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
-  auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
-  set_tensor_attrs(q_, Q, q);
-  set_tensor_attrs(k_, K, k);
-  set_tensor_attrs(v_, V, v);
-
-  auto scale = graph.tensor(fe::graph::Tensor_attributes()
-                                .set_name("Scale")
-                                .set_uid(SCALE)
-                                .set_dim({1, 1, 1, 1})
-                                .set_stride({1, 1, 1, 1})
-                                .set_is_pass_by_value(true)
-                                .set_data_type(fe::DataType_t::FLOAT));
-
   auto options = fe::graph::SDPA_attributes()
                      .set_name("sdpa_cudnn")
-                     .set_attn_scale(scale)
+                     .set_attn_scale(graph.scalar("Scale", SCALE, float32))
                      .set_generate_stats(output_logsumexp);
   if (do_causal) {
     if (q.shape(2) > k.shape(2)) {
@@ -216,31 +159,23 @@ fe::graph::Graph build_sdpa_graph(
     }
   }
   if (mask_arr) {
-    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
-    set_tensor_attrs(bias_, BIAS, *mask_arr);
-    options.set_bias(bias_);
+    options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
   }
 
   auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
-  o_->set_output(true);
-  set_tensor_attrs(o_, O, o);
+  graph.tensor(o_, O, o)->set_output(true);
   if (output_logsumexp) {
-    stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
-    set_tensor_attrs(stats_, STATS, stats);
+    graph.tensor(stats_, STATS, stats)->set_output(true);
   }
 
-  CHECK_CUDNN_FE_ERROR(graph.validate());
-  CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
-  CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
+  CHECK_CUDNN_FE_ERROR(graph.prepare());
   graph.select_behavior_notes(
       {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
-  CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
-  CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
-
+  CHECK_CUDNN_FE_ERROR(graph.build());
   return graph;
 }
 
-fe::graph::Graph build_sdpa_backward_graph(
+DnnGraph build_sdpa_backward_graph(
     cudnnHandle_t handle,
     const array& q,
     const array& k,
@@ -253,42 +188,18 @@ fe::graph::Graph build_sdpa_backward_graph(
     array& d_q,
     array& d_k,
     array& d_v) {
-  auto dtype = fe::DataType_t::HALF;
-  if (q.dtype() == bfloat16) {
-    dtype = fe::DataType_t::BFLOAT16;
-  }
+  DnnGraph graph(handle, q.dtype());
 
-  fe::graph::Graph graph;
-  graph.set_io_data_type(dtype)
-      .set_intermediate_data_type(fe::DataType_t::FLOAT)
-      .set_compute_data_type(fe::DataType_t::FLOAT);
-
-  auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
-  auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
-  auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
-  auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
-  auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
-  auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
-  set_tensor_attrs(q_, Q, q);
-  set_tensor_attrs(k_, K, k);
-  set_tensor_attrs(v_, V, v);
-  set_tensor_attrs(o_, O, o);
-  set_tensor_attrs(d_o_, D_O, d_o);
-  set_tensor_attrs(stats_, STATS, stats);
-  stats_->set_data_type(fe::DataType_t::FLOAT);
-
-  auto scale = graph.tensor(fe::graph::Tensor_attributes()
-                                .set_name("Scale")
-                                .set_uid(SCALE)
-                                .set_dim({1, 1, 1, 1})
-                                .set_stride({1, 1, 1, 1})
-                                .set_is_pass_by_value(true)
-                                .set_data_type(fe::DataType_t::FLOAT));
+  auto q_ = graph.tensor("Q", Q, q);
+  auto k_ = graph.tensor("K", K, k);
+  auto v_ = graph.tensor("V", V, v);
+  auto o_ = graph.tensor("O", O, o);
+  auto d_o_ = graph.tensor("D_O", D_O, d_o);
+  auto stats_ = graph.tensor("STATS", STATS, stats);
 
   auto options = fe::graph::SDPA_backward_attributes()
                      .set_name("sdpa_backward_cudnn")
-                     .set_attn_scale(scale)
-                     .set_attn_scale(scale);
+                     .set_attn_scale(graph.scalar("Scale", SCALE, float32));
   if (do_causal) {
     if (q.shape(2) > k.shape(2)) {
       options.set_causal_mask(do_causal);
@@ -297,56 +208,22 @@ fe::graph::Graph build_sdpa_backward_graph(
     }
   }
   if (mask_arr) {
-    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
-    set_tensor_attrs(bias_, BIAS, *mask_arr);
-    options.set_bias(bias_);
+    options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
   }
 
   auto [d_q_, d_k_, d_v_] =
       graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
-  d_q_->set_output(true);
-  d_k_->set_output(true);
-  d_v_->set_output(true);
-  set_tensor_attrs(d_q_, D_Q, d_q);
-  set_tensor_attrs(d_k_, D_K, d_k);
-  set_tensor_attrs(d_v_, D_V, d_v);
+  graph.tensor(d_q_, D_Q, d_q)->set_output(true);
+  graph.tensor(d_k_, D_K, d_k)->set_output(true);
+  graph.tensor(d_v_, D_V, d_v)->set_output(true);
 
-  CHECK_CUDNN_FE_ERROR(graph.validate());
-  CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
-  CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
+  CHECK_CUDNN_FE_ERROR(graph.prepare());
   graph.select_behavior_notes(
      {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
-  CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
-  CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
+  CHECK_CUDNN_FE_ERROR(graph.build());
 
   return graph;
 }
 
-void execute_graph(
-    cu::CommandEncoder& encoder,
-    cudnnHandle_t handle,
-    fe::graph::Graph& graph,
-    std::unordered_map<int64_t, void*>& variant_pack) {
-  int64_t workspace_size = 0;
-  CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
-  void* workspace_ptr = nullptr;
-  if (workspace_size > 0) {
-    array workspace(
-        cu::malloc_async(workspace_size, encoder),
-        {static_cast<int>(workspace_size)},
-        uint8);
-    encoder.add_temporary(workspace);
-    workspace_ptr = gpu_ptr<void>(workspace);
-  }
-
-  cudnnSetStream(handle, encoder.stream());
-
-  CudaGraph cuda_graph(encoder.device());
-  CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
-      handle, variant_pack, workspace_ptr, cuda_graph));
-  encoder.add_graph_node(cuda_graph);
-}
-
 } // namespace
 
 bool supports_sdpa_cudnn(
@@ -420,19 +297,19 @@ void sdpa_cudnn(
   auto& graph = it->second;
 
   std::unordered_map<int64_t, void*> variant_pack{
-      {Q, const_cast<void*>(gpu_ptr<void>(q))},
-      {K, const_cast<void*>(gpu_ptr<void>(k))},
-      {V, const_cast<void*>(gpu_ptr<void>(v))},
+      {Q, gpu_ptr<void>(q)},
+      {K, gpu_ptr<void>(k)},
+      {V, gpu_ptr<void>(v)},
       {SCALE, &scale},
       {O, gpu_ptr<void>(o)}};
   if (mask_arr) {
-    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
+    variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
   }
   if (output_logsumexp) {
     variant_pack[STATS] = gpu_ptr<void>(stats);
   }
 
-  execute_graph(encoder, handle, graph, variant_pack);
+  CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
 }
 
 void sdpa_backward_cudnn(
@@ -480,21 +357,21 @@ void sdpa_backward_cudnn(
   auto& graph = it->second;
 
   std::unordered_map<int64_t, void*> variant_pack{
-      {Q, const_cast<void*>(gpu_ptr<void>(q))},
-      {K, const_cast<void*>(gpu_ptr<void>(k))},
-      {V, const_cast<void*>(gpu_ptr<void>(v))},
+      {Q, gpu_ptr<void>(q)},
+      {K, gpu_ptr<void>(k)},
+      {V, gpu_ptr<void>(v)},
       {SCALE, &scale},
-      {O, const_cast<void*>(gpu_ptr<void>(o))},
-      {STATS, const_cast<void*>(gpu_ptr<void>(stats))},
-      {D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
+      {O, gpu_ptr<void>(o)},
+      {STATS, gpu_ptr<void>(stats)},
+      {D_O, gpu_ptr<void>(d_o)},
       {D_Q, gpu_ptr<void>(d_q)},
       {D_K, gpu_ptr<void>(d_k)},
       {D_V, gpu_ptr<void>(d_v)}};
   if (mask_arr) {
-    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
+    variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
   }
 
-  execute_graph(encoder, handle, graph, variant_pack);
+  CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
 }
 
 // Defined in scaled_dot_product_attention.cu file.
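Both call sites now move the variant pack into `graph.encode_graph(...)`, which presumably takes over the workspace allocation, stream binding, and CUDA-graph capture that the deleted `execute_graph` helper did by hand. For orientation, a user-level sketch of the path these kernels serve, assuming only the public `mx.fast.scaled_dot_product_attention` and `mx.grad` entry points (shapes and the loss are illustrative; whether the cuDNN graph is actually used still depends on the checks in `supports_sdpa_cudnn`):

    import mlx.core as mx

    # Illustrative (batch, heads, sequence, head_dim) inputs.
    q = mx.random.normal((1, 8, 128, 64))
    k = mx.random.normal((1, 8, 128, 64))
    v = mx.random.normal((1, 8, 128, 64))

    def loss(q, k, v):
        o = mx.fast.scaled_dot_product_attention(q, k, v, scale=64**-0.5)
        return o.sum()

    # Gradients w.r.t. q, k, and v flow through the fused attention primitive;
    # on CUDA builds this is the call path the backward graph above serves.
    dq, dk, dv = mx.grad(loss, argnums=(0, 1, 2))(q, k, v)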
@@ -32,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) {
   }
 }
 
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 const char* dtype_to_cuda_type(const Dtype& dtype) {
   switch (dtype) {
     case bool_:
@@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) {
       arr.offset());
 }
 
+// For const array, keep constness in pointer unless it is untyped.
 template <typename T>
-inline const T* gpu_ptr(const array& arr) {
+inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
+    const array& arr) {
   return gpu_ptr<T>(const_cast<array&>(arr));
 }
 
@@ -16,77 +16,7 @@ INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
 OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
 
 mkdir -p "$OUTPUT_DIR"
-# CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
+CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
 
-CCC="xcrun -sdk macosx metal -x metal"
-
-HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
-
-declare -a HDRS_LIST=($HDRS)
-declare -a HDRS_STACK=()
-declare -a HDRS_SORTED=()
-
-length=${#HDRS_LIST[@]}
-
-HDRS_LIST+=(".")
-
-for ((i=0; i<${length}; i+=2));
-do
-
-  header="${HDRS_LIST[$i+1]#$SRC_DIR/}"
-
-  str_this="${HDRS_LIST[$i]}"
-  str_next="${HDRS_LIST[$i + 2]}"
-
-  depth_this=${#str_this}
-  depth_next=${#str_next}
-
-  # If we have a dependency then we stack it
-  if [ $depth_next -gt $depth_this ]; then
-    HDRS_STACK=($header ${HDRS_STACK[@]})
-
-  # If we are done with this level
-  else
-    # We add the header to out list
-    HDRS_SORTED+=($header)
-
-    # Pop the stacked up dependencies
-    pop_len=$((depth_this - depth_next))
-    for popped_header in "${HDRS_STACK[@]:0:$pop_len}"
-    do
-      HDRS_SORTED+=($popped_header)
-    done
-
-    HDRS_STACK=(${HDRS_STACK[@]:$pop_len})
-  fi
-
-done
-
-HDRS_SORTED+=("${INPUT_FILE#$SRC_DIR/}")
-
-CONTENT=$(
-  echo "// Copyright © 2025 Apple Inc."
-  echo ""
-  echo "// Auto generated source for $INPUT_FILE"
-  echo ""
-
-  for header in "${HDRS_SORTED[@]}"
-  do
-    echo "///////////////////////////////////////////////////////////////////////////////"
-    echo "// Contents from \"${header}\""
-    echo "///////////////////////////////////////////////////////////////////////////////"
-    echo ""
-
-    echo "#line 1 \"${header}\""
-
-    grep -h -v -G -e "#include \".*.h\"" -e "#pragma once" "${SRC_DIR}/${header}"
-
-    echo ""
-
-  done
-
-  echo "///////////////////////////////////////////////////////////////////////////////"
-)
-
 cat << EOF > "$OUTPUT_FILE"
 namespace mlx::core::metal {
@@ -382,6 +382,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(LogicalOr),
       SERIALIZE_PRIMITIVE(LogAddExp),
       SERIALIZE_PRIMITIVE(LogSumExp),
+      SERIALIZE_PRIMITIVE(MaskedScatter),
       SERIALIZE_PRIMITIVE(Matmul),
       SERIALIZE_PRIMITIVE(Maximum),
       SERIALIZE_PRIMITIVE(Minimum),
@@ -3466,10 +3466,8 @@ array masked_scatter(
   if (mask.dtype() != bool_) {
     throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
   }
-  if (mask.ndim() == 0) {
-    throw std::invalid_argument(
-        "[masked_scatter] Scalar masks are not supported.");
-  } else if (mask.ndim() > a.ndim()) {
+  if (mask.ndim() > a.ndim()) {
     throw std::invalid_argument(
         "[masked_scatter] The mask cannot have more dimensions than the target.");
   }
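With the zero-dimension check dropped, `masked_scatter` accepts scalar boolean masks and broadcasts them over the whole target; non-scalar masks must still match the shape of the axes they index. A small user-level sketch of the behavior this enables, mirroring the new tests further down (values are illustrative):

    import mlx.core as mx

    a = mx.array([1.0, 2.0, 3.0])

    # A boolean mask over the first axis picks out positions 0 and 2.
    a[mx.array([True, False, True])] = mx.array([5.0, 6.0])
    # a is now [5, 2, 6]

    # A scalar boolean mask broadcasts over the full array.
    b = mx.array([1.0, 2.0, 3.0])
    b[mx.array(True)] = 0.0
    # b is now [0, 0, 0]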
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
   }
 }
 
+std::vector<array> Reduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto in = primals[0];
+  auto s = stream();
+
+  auto grad_op = [&s, reduce_type = reduce_type_](
+                     const array& x, const array& tan, int axis) {
+    if (reduce_type == Reduce::Min) {
+      auto idx = argmin(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else if (reduce_type == Reduce::Max) {
+      auto idx = argmax(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else {
+      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
+      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
+      auto out = multiply(multiply(p1, p2, s), tan, s);
+      return sum(out, axis, true, s);
+    }
+  };
+
+  auto tan = tangents[0];
+  if (reduce_type_ == Reduce::Sum) {
+    return {sum(tan, axes_, true, s)};
+  } else {
+    if (axes_.size() > 1) {
+      std::vector<int> transpose_to;
+      {
+        // Find the transpose needed to move axes_ to the back.
+        int j = 0;
+        for (int i = 0; i < in.ndim(); i++) {
+          if (j < axes_.size() && axes_[j] == i) {
+            j++;
+          } else {
+            transpose_to.push_back(i);
+          }
+        }
+        for (auto ax : axes_) {
+          transpose_to.push_back(ax);
+        }
+      }
+
+      int start_ax = in.ndim() - axes_.size();
+      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
+      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
+
+      auto grad = squeeze(grad_op(in, tan, -1), -1, s);
+      return {expand_dims(grad, axes_, s)};
+    } else {
+      return {grad_op(in, tan, axes_[0])};
+    }
+  }
+}
+
 std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
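In the new `Reduce::jvp`, Min and Max route the tangent through the arg-reduction index, and the remaining (Prod) case uses exclusive forward and reverse cumulative products so that each element's coefficient (the product of all the other elements along the axis) is formed without dividing by the element itself, which keeps zero entries well defined. A quick numerical sketch of that Prod rule using the public Python ops; values are illustrative and match the new `test_reduce_jvp` further down:

    import mlx.core as mx

    a = mx.arange(4)            # [0, 1, 2, 3]
    t = mx.array([3, 2, 1, 0])  # tangent

    # Exclusive prefix and suffix products; their product is d(prod)/d(a_i).
    p1 = mx.cumprod(a, reverse=False, inclusive=False)  # [1, 0, 0, 0]
    p2 = mx.cumprod(a, reverse=True, inclusive=False)   # [6, 6, 3, 1]
    print((p1 * p2 * t).sum())  # 18

    out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(t,))
    print(jout[0])  # 18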
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+  DEFINE_GRADS();
 
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 
@@ -1871,13 +1866,13 @@ class Scatter : public UnaryPrimitive {
   const char* name() const override {
     switch (reduce_type_) {
       case Sum:
-        return "ScatterSum";
+        return "Scatter Sum";
       case Prod:
-        return "ScatterProd";
+        return "Scatter Prod";
       case Min:
-        return "ScatterMin";
+        return "Scatter Min";
       case Max:
-        return "ScatterMax";
+        return "Scatter Max";
       case None:
         return "Scatter";
     }
@@ -1910,7 +1905,7 @@ class ScatterAxis : public UnaryPrimitive {
   const char* name() const override {
     switch (reduce_type_) {
       case Sum:
-        return "ScatterAxisSum";
+        return "ScatterAxis Sum";
       case None:
         return "ScatterAxis";
     }
@@ -766,7 +766,7 @@ auto mlx_slice_update(
     const nb::object& obj,
     const ScalarOrArray& v) {
   // Can't route to slice update if not slice, tuple, or int
-  if (src.ndim() == 0 ||
+  if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
       (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
        !nb::isinstance<nb::int_>(obj))) {
     return std::make_pair(false, src);
@@ -888,7 +888,9 @@ auto mlx_slice_update(
 
 std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
   using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
-  if (nb::isinstance<mx::array>(obj)) {
+  if (nb::isinstance<nb::bool_>(obj)) {
+    return mx::array(nb::cast<bool>(obj), mx::bool_);
+  } else if (nb::isinstance<mx::array>(obj)) {
     auto mask = nb::cast<mx::array>(obj);
     if (mask.dtype() == mx::bool_) {
       return mask;
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
     if (mask.dtype() == nb::dtype<bool>()) {
       return nd_array_to_mlx(mask, mx::bool_);
     }
+  } else if (nb::isinstance<nb::list>(obj)) {
+    auto mask = array_from_list(nb::cast<nb::list>(obj), {});
+    if (mask.dtype() == mx::bool_) {
+      return mask;
+    }
   }
   return std::nullopt;
 }
@@ -1929,8 +1929,27 @@ class TestArray(mlx_tests.MLXTestCase):
         self.assertTrue(np.array_equal(a, anp))
 
     def test_setitem_with_boolean_mask(self):
-        mask_np = np.zeros((10, 10), dtype=bool)
-        mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
+        # Python list mask
+        a = mx.array([1.0, 2.0, 3.0])
+        mask = [True, False, True]
+        src = mx.array([5.0, 6.0])
+        expected = mx.array([5.0, 2.0, 6.0])
+        a[mask] = src
+        self.assertTrue(mx.array_equal(a, expected))
+
+        # mx.array scalar mask
+        a = mx.array([1.0, 2.0, 3.0])
+        mask = mx.array(True)
+        expected = mx.array([5.0, 5.0, 5.0])
+        a[mask] = 5.0
+        self.assertTrue(mx.array_equal(a, expected))
+
+        # scalar mask
+        a = mx.array([1.0, 2.0, 3.0])
+        mask = True
+        expected = mx.array([5.0, 5.0, 5.0])
+        a[mask] = 5.0
+        self.assertTrue(mx.array_equal(a, expected))
 
         mask_np = np.zeros((1, 10, 10), dtype=bool)
         with self.assertRaises(ValueError):
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
 
+    def test_reduce_jvp(self):
+        a = mx.arange(4)
+        b = mx.array([3, 2, 1, 0])
+
+        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 6)
+
+        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 18)
+
+        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 3)
+
+        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
@@ -596,6 +596,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
         for y in ys:
             self.assertEqual(imported(y)[0].item(), fun(y).item())
 
+    def test_export_import_scatter_sum(self):
+        def fun(x, y, z):
+            return x.at[y].add(z)
+
+        x = mx.array([1, 2, 3])
+        y = mx.array([0, 0, 1])
+        z = mx.array([1, 1, 1])
+        path = os.path.join(self.test_dir, "fn.mlxfn")
+        mx.export_function(path, fun, x, y, z)
+
+        imported = mx.import_function(path)
+        self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()