mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-04 01:36:42 +08:00
50 lines
1.2 KiB
C++
50 lines
1.2 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
#include "mlx/backend/gpu/copy.h"
|
|
#include "mlx/primitives.h"
|
|
|
|
#include <cassert>
|
|
|
|
namespace mlx::core {
|
|
|
|
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
|
bool donated = set_copy_output_data(in, out, ctype);
|
|
if (donated && in.dtype() == out.dtype()) {
|
|
// If the output has the same type as the input then there is nothing to
|
|
// copy, just use the buffer.
|
|
return;
|
|
}
|
|
if (ctype == CopyType::GeneralGeneral) {
|
|
ctype = CopyType::General;
|
|
}
|
|
copy_gpu_inplace(in, out, ctype, s);
|
|
}
|
|
|
|
// Convenience overload: copy `in` into `out` on the stream reported by the
// primitive attached to `out`.
void copy_gpu(const array& in, array& out, CopyType ctype) {
  const auto& out_stream = out.primitive().stream();
  copy_gpu(in, out, ctype, out_stream);
}
|
|
|
|
void copy_gpu_inplace(
|
|
const array& in,
|
|
array& out,
|
|
CopyType ctype,
|
|
const Stream& s) {
|
|
assert(in.shape() == out.shape());
|
|
return copy_gpu_inplace(
|
|
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
|
}
|
|
|
|
void copy_gpu_inplace(
|
|
const array& in,
|
|
array& out,
|
|
const Strides& i_strides,
|
|
int64_t i_offset,
|
|
CopyType ctype,
|
|
const Stream& s) {
|
|
assert(in.shape() == out.shape());
|
|
return copy_gpu_inplace(
|
|
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
|
}
|
|
|
|
} // namespace mlx::core
|