mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Backward convolution (#2431)
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
#include <cub/device/device_segmented_sort.cuh>
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -27,29 +26,6 @@ struct ModOp {
|
||||
}
|
||||
};
|
||||
|
||||
// We can not use any op in eval, make an utility.
|
||||
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
|
||||
std::vector<int> axes(in.ndim());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[axis1], axes[axis2]);
|
||||
// TODO: Share the code with Transpose::eval.
|
||||
Shape shape(axes.size());
|
||||
Strides strides(in.ndim());
|
||||
for (size_t ax = 0; ax < axes.size(); ++ax) {
|
||||
shape[ax] = in.shape()[axes[ax]];
|
||||
strides[ax] = in.strides()[axes[ax]];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous) {
|
||||
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
}
|
||||
array out(shape, in.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
struct OffsetTransform {
|
||||
int nsort;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user