mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add contiguous_copy_gpu util for copying array (#2379)
This commit is contained in:
@@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
||||
array wt_transpose = contiguous_copy_gpu(wt_view, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
@@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = inputs[0];
|
||||
auto wt = inputs[1];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
in = arr_copy;
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
copies.push_back(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
|
||||
copy_gpu(wt, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
wt = arr_copy;
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
copies.push_back(wt);
|
||||
}
|
||||
|
||||
// 3D conv
|
||||
|
||||
Reference in New Issue
Block a user