mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add contiguous_copy_gpu util for copying array (#2379)
This commit is contained in:
@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
@@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
|
||||
return x;
|
||||
}
|
||||
copied = true;
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
return x_copy;
|
||||
return contiguous_copy_gpu(x, s);
|
||||
};
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
|
||||
Reference in New Issue
Block a user