mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file * [CUDA] Fix conv grads with groups * Put the reshape utils in gpu/copy.h
This commit is contained in:
@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||
|
||||
// Copy data from |in| and transpose to |out|'s shape.
|
||||
void reshape_gpu(const array& in, array& out, Stream s);
|
||||
|
||||
// Like the normal ops but safe to call in eval_gpu.
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user