mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Backward convolution (#2431)
This commit is contained in:
@@ -196,6 +196,9 @@ void shared_buffer_reshape(
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
// Like the swapaxes op but safe to call in eval_gpu.
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
|
||||
Reference in New Issue
Block a user