[CUDA] Backward convolution (#2431)

This commit is contained in:
Cheng
2025-08-01 09:54:05 +09:00
committed by GitHub
parent 8b25ce62d5
commit 86c6a15571
5 changed files with 315 additions and 112 deletions

View File

@@ -196,6 +196,9 @@ void shared_buffer_reshape(
const Strides& out_strides,
array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));