MLX
 
Loading...
Searching...
No Matches
copy.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include "mlx/array.h"
6
7namespace mlx::core {
8
9enum class CopyType {
10 // Copy a raw scalar input into the full contiguous output
12
13 // Copy the raw input buffer contiguously into a raw output buffer of the same
14 // size
16
17 // Copy the full virtual input to the full contiguous output
19
20 // Copy the full virtual input to the full virtual output. We assume the
21 // input and output have the same shape.
23};
24
25inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
26 if (ctype == CopyType::Vector) {
27 // If the input is donateable, we are doing a vector copy and the types
28 // have the same size, then the input buffer can hold the output.
29 if (in.is_donatable() && in.itemsize() == out.itemsize()) {
30 out.copy_shared_buffer(in);
31 return true;
32 } else {
33 out.set_data(
35 in.data_size(),
36 in.strides(),
37 in.flags());
38 return false;
39 }
40 } else {
42 return false;
43 }
44}
45
46} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:313
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
bool is_donatable() const
True indicates the arrays buffer is safe to reuse.
Definition array.h:278
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:327
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
bool set_copy_output_data(const array &in, array &out, CopyType ctype)
Definition copy.h:25
@ General
Definition copy.h:18
CopyType
Definition copy.h:9
@ Vector
Definition copy.h:15
@ GeneralGeneral
Definition copy.h:22
@ Scalar
Definition copy.h:11