[CUDA] Backward convolution (#2431)

This commit is contained in:
Cheng
2025-08-01 09:54:05 +09:00
committed by GitHub
parent 8b25ce62d5
commit 86c6a15571
5 changed files with 315 additions and 112 deletions

View File

@@ -13,7 +13,6 @@
#include <cub/device/device_segmented_sort.cuh>
#include <cassert>
#include <numeric>
namespace mlx::core {
@@ -27,29 +26,6 @@ struct ModOp {
}
};
// We can not use any op in eval, make an utility.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
// TODO: Share the code with Transpose::eval.
Shape shape(axes.size());
Strides strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
strides[ax] = in.strides()[axes[ax]];
}
auto flags = in.flags();
if (flags.contiguous) {
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
}
array out(shape, in.dtype(), nullptr, {});
out.copy_shared_buffer(in, strides, flags, in.data_size());
return out;
}
struct OffsetTransform {
int nsort;