Files
mlx/mlx/backend/common/copy.h

51 lines
1.2 KiB
C
Raw Normal View History

// Copyright © 2023-2024 Apple Inc.
2023-11-30 11:12:53 -08:00
2023-11-29 10:52:08 -08:00
#pragma once
2025-06-09 22:45:08 +09:00
#include "mlx/backend/common/utils.h"
2023-11-29 10:52:08 -08:00
namespace mlx::core {
// Selects how a copy reads its source and how its destination is laid out.
// set_copy_output_data branches on CopyType::Vector to decide how the output
// buffer is obtained.
enum class CopyType {
  // A single raw scalar taken from the input is written across the entire
  // contiguous output.
  Scalar = 0,

  // The raw input buffer is copied contiguously into a raw output buffer of
  // equal size.
  Vector = 1,

  // The whole virtual input is copied into a whole contiguous output.
  General = 2,

  // The whole virtual input is copied into a whole virtual output. Input and
  // output are assumed to have identical shapes.
  GeneralGeneral = 3
};
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
2025-06-09 22:45:08 +09:00
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
return true;
} else {
out.set_data(
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(mallocfn(out.nbytes()));
return false;
}
}
2023-11-29 10:52:08 -08:00
} // namespace mlx::core