mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides * Fix uniform buffer indices or copy kernel arguments * Update utils.h * Remove manual unrolling of elem to loc loop * GPU copy updated to handle negative strides * Add slice update primitive
This commit is contained in:
@@ -9,16 +9,43 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
int idx) {
|
||||
inline void
|
||||
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
enc->setBuffer(a_buf, offset, idx);
|
||||
}
|
||||
|
||||
inline void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
int64_t offset,
|
||||
int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const std::vector<T>& vec,
|
||||
size_t nelems,
|
||||
int idx) {
|
||||
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const std::vector<T>& vec,
|
||||
int idx) {
|
||||
return set_vector_bytes(enc, vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
std::string type_to_name(const array& a) {
|
||||
std::string tname;
|
||||
switch (a.dtype()) {
|
||||
|
||||
Reference in New Issue
Block a user