mlx/mlx/backend/common/copy.h
Jagrit Digani cec8661113
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00

41 lines
975 B
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
/// Describes how a copy between two arrays is to be carried out. The value is
/// passed as the `ctype` argument of the copy routines declared below.
enum class CopyType {
  /// Copy a raw scalar input into the full contiguous output.
  Scalar,

  /// Copy the raw input buffer contiguously into a raw output buffer of the
  /// same size.
  Vector,

  /// Copy the full virtual input to the full contiguous output.
  General,

  /// Copy the full virtual input to the full virtual output. The input and
  /// the output are assumed to have the same shape.
  GeneralGeneral,
};
// Copy `src` into `dst` using the strategy selected by `ctype`.
//
// NOTE(review): only the declaration is visible in this header. Presumably
// this variant is responsible for setting up the storage of `dst` (in contrast
// to `copy_inplace`) -- confirm against the definition.
void copy(const array& src, array& dst, CopyType ctype);
// Copy `src` into `dst` using the strategy selected by `ctype`.
//
// NOTE(review): going by the name, this presumably writes into storage that
// `dst` already owns rather than allocating it -- confirm against the
// definition, which is not part of this header.
void copy_inplace(const array& src, array& dst, CopyType ctype);
// Overload of `copy_inplace` that receives the shape, the strides and the
// starting offsets as explicit arguments. `stride_t` is the element type of
// the two stride vectors.
//
// Only the declaration appears in this header; the definition (and therefore
// the set of `stride_t` types that can actually be used) lives elsewhere.
//
// NOTE(review): the parameter descriptions below are inferred from the names
// and types only -- confirm against the definition.
template <typename stride_t>
void copy_inplace(
// Source array (read-only).
const array& src,
// Destination array (written to).
array& dst,
// Presumably the extent of each dimension of the region to copy.
const std::vector<int>& data_shape,
// Presumably the per-dimension strides used to walk `src`.
const std::vector<stride_t>& i_strides,
// Presumably the per-dimension strides used to walk `dst`.
const std::vector<stride_t>& o_strides,
// Presumably the starting offset into `src`; unit (elements vs. bytes) is
// not determinable from this header -- TODO confirm.
int64_t i_offset,
// Presumably the starting offset into `dst`; same caveat on the unit.
int64_t o_offset,
// Copy strategy, see `CopyType`.
CopyType ctype);
} // namespace mlx::core