Add a SliceUpdate op and primitive (#850)

* Enable copy to work with int64 strides
* Fix uniform buffer indices for copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
This commit is contained in:
Jagrit Digani
2024-03-20 10:39:25 -07:00
committed by GitHub
parent 73a8c090e0
commit cec8661113
21 changed files with 1147 additions and 506 deletions

View File

@@ -9,16 +9,43 @@ namespace mlx::core {
namespace {
void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int idx) {
// Binds the Metal buffer backing `a` to argument slot `idx` of `enc`.
// The bound offset is the byte distance from the start of the buffer's
// contents to the array's data pointer.
inline void
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
  auto* buffer = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  // contents() is invoked through a non-const pointer, hence the const_cast.
  auto* buffer_start =
      static_cast<char*>(const_cast<MTL::Buffer*>(buffer)->contents());
  auto byte_offset = a.data<char>() - buffer_start;
  enc->setBuffer(buffer, byte_offset, idx);
}
// Binds the Metal buffer backing `a` to argument slot `idx` of `enc`,
// shifted by an extra `offset`.
//
// `offset` is added directly to a byte distance (a difference of char
// pointers), so it is expressed in bytes, not elements.
inline void set_array_buffer(
    MTL::ComputeCommandEncoder* enc,
    const array& a,
    int64_t offset,
    int idx) {
  auto* buffer = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  // contents() is invoked through a non-const pointer, hence the const_cast.
  auto* buffer_start =
      static_cast<char*>(const_cast<MTL::Buffer*>(buffer)->contents());
  // Byte distance from the start of the buffer contents to the array data.
  auto byte_offset = a.data<char>() - buffer_start;
  byte_offset += offset;
  enc->setBuffer(buffer, byte_offset, idx);
}
// Passes the first `nelems` elements of `vec` to `enc` as inline bytes at
// argument slot `idx`. No check is made that `nelems` fits within `vec`.
template <typename T>
inline void set_vector_bytes(
    MTL::ComputeCommandEncoder* enc,
    const std::vector<T>& vec,
    size_t nelems,
    int idx) {
  const size_t num_bytes = nelems * sizeof(T);
  enc->setBytes(vec.data(), num_bytes, idx);
}
// Convenience overload: passes every element of `vec` to `enc` as inline
// bytes at argument slot `idx`, forwarding to the overload that takes an
// explicit element count.
template <typename T>
inline void set_vector_bytes(
    MTL::ComputeCommandEncoder* enc,
    const std::vector<T>& vec,
    int idx) {
  set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {