mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides * Fix uniform buffer indices or copy kernel arguments * Update utils.h * Remove manual unrolling of elem to loc loop * GPU copy updated to handle negative strides * Add slice update primitive
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
@@ -37,15 +37,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& strides_in_pre,
|
||||
const std::vector<stride_t>& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(in, out);
|
||||
auto& strides_in = strides[0];
|
||||
auto& strides_out = strides[1];
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector{strides_in_pre, strides_out_pre});
|
||||
auto& strides_in_ = strides[0];
|
||||
auto& strides_out_ = strides[1];
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
std::ostringstream kname;
|
||||
@@ -72,39 +79,44 @@ void copy_gpu_inplace(
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
inp_offset *= size_of(in.dtype());
|
||||
out_offset *= size_of(out.dtype());
|
||||
|
||||
set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
|
||||
set_array_buffer(compute_encoder, out, out_offset, 1);
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
size_t ndim = shape.size();
|
||||
int ndim = shape.size();
|
||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, shape, ndim, 2);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(
|
||||
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 5);
|
||||
}
|
||||
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
int rest = in.size() / (dim0 * dim1);
|
||||
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
int rest = data_size / (dim0 * dim1);
|
||||
|
||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||
}
|
||||
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
@@ -120,4 +132,25 @@ void copy_gpu_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int64_t>& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,12 +7,34 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype);
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& src,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int64_t>& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,29 +1,29 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_s(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_v(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t& src_stride,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
@@ -31,61 +31,61 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[2],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[3],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_g_nd(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int src_shape[DIM],
|
||||
constant const size_t src_strides[DIM],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int* src_shape,
|
||||
constant const size_t* src_strides,
|
||||
constant const int& ndim,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t& src_stride,
|
||||
constant const size_t& dst_stride,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
@@ -94,10 +94,10 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[2],
|
||||
constant const size_t dst_strides[2],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
@@ -106,10 +106,10 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[3],
|
||||
constant const size_t dst_strides[3],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
@@ -118,11 +118,11 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_gg_nd(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int src_shape[DIM],
|
||||
constant const size_t src_strides[DIM],
|
||||
constant const size_t dst_strides[DIM],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
|
||||
@@ -131,12 +131,12 @@ template <typename T, typename U, int DIM>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int* src_shape,
|
||||
constant const size_t* src_strides,
|
||||
constant const size_t* dst_strides,
|
||||
constant const int& ndim,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
|
||||
@@ -146,70 +146,70 @@ template <typename T, typename U>
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_##ctype<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int src_shape[dims], \
|
||||
constant const size_t src_strides[dims], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_" #dims)]] \
|
||||
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int src_shape[dims], \
|
||||
constant const size_t src_strides[dims], \
|
||||
constant const size_t dst_strides[dims], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t& src_stride, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[2], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[3], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_1")]] \
|
||||
[[kernel]] void copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t& src_stride, \
|
||||
constant const size_t& dst_stride, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_2")]] \
|
||||
[[kernel]] void copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[2], \
|
||||
constant const size_t dst_strides[2], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_3")]] \
|
||||
[[kernel]] void copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[3], \
|
||||
constant const size_t dst_strides[3], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
@@ -218,21 +218,21 @@ template <typename T, typename U>
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int* src_shape, \
|
||||
constant const size_t* src_strides, \
|
||||
constant const int& ndim, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name)]] \
|
||||
[[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int* src_shape, \
|
||||
constant const size_t* src_strides, \
|
||||
constant const size_t* dst_strides, \
|
||||
constant const int& ndim, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -65,12 +65,18 @@ struct Limits<bool> {
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Single Array with generic dims
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides,
|
||||
device const stride_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
stride_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
@@ -78,12 +84,13 @@ inline size_t elem_to_loc(
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* strides,
|
||||
constant const stride_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
stride_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
@@ -91,52 +98,59 @@ inline size_t elem_to_loc(
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint3 elem_to_loc_3_nd(
|
||||
// Non templated version to handle arbitrary dims
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM],
|
||||
constant const size_t c_strides[NDIM]) {
|
||||
uint3 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
loc.z += l * c_strides[d];
|
||||
constant const int* shape,
|
||||
constant const stride_t* strides,
|
||||
int ndim) {
|
||||
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Single Array with fixed N dims
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
|
||||
return elem * stride;
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t
|
||||
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
|
||||
return elem.x * strides[1] + elem.y * strides[0];
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t
|
||||
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
|
||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM]) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
METAL_FUNC size_t elem_to_loc_nd(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int d = NDIM - 2; d >= 0; --d) {
|
||||
elem /= shape[d + 1];
|
||||
loc += (elem % shape[d]) * strides[d];
|
||||
}
|
||||
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline size_t elem_to_loc_nd(
|
||||
METAL_FUNC size_t elem_to_loc_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t strides[NDIM]) {
|
||||
@@ -148,33 +162,59 @@ inline size_t elem_to_loc_nd(
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
||||
return elem * stride;
|
||||
template <int NDIM>
|
||||
METAL_FUNC int64_t elem_to_loc_nd(
|
||||
uint elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const int64_t strides[NDIM]) {
|
||||
int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
|
||||
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int d = NDIM - 2; d >= 0; --d) {
|
||||
elem /= shape[d + 1];
|
||||
loc += (elem % shape[d]) * strides[d];
|
||||
}
|
||||
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
||||
return elem.x * strides[1] + elem.y * strides[0];
|
||||
}
|
||||
|
||||
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
||||
}
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
inline size_t elem_to_loc(
|
||||
template <int NDIM>
|
||||
METAL_FUNC int64_t elem_to_loc_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
constant const int shape[NDIM],
|
||||
constant const int64_t strides[NDIM]) {
|
||||
int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
loc += (elem.z % shape[d]) * strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline uint3 elem_to_loc_3_nd(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multiple Arrays with generic dims
|
||||
|
||||
METAL_FUNC uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
int ndim) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
METAL_FUNC uint3 elem_to_loc_3_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
@@ -198,18 +238,21 @@ inline uint3 elem_to_loc_3_nd(
|
||||
return loc;
|
||||
}
|
||||
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Multiple Arrays with fixed N dims
|
||||
|
||||
template <int NDIM>
|
||||
METAL_FUNC uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
int ndim) {
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM]) {
|
||||
uint2 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
@@ -219,55 +262,26 @@ inline uint2 elem_to_loc_2_nd(
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint elem_to_loc_nd(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides);
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<1>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
return (elem % shape[0]) * strides[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<2>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<3>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint elem_to_loc_nd<4>(
|
||||
uint elem,
|
||||
device const int* shape,
|
||||
device const size_t* strides) {
|
||||
uint loc = (elem % shape[3]) * strides[3];
|
||||
elem /= shape[3];
|
||||
loc += (elem % shape[2]) * strides[2];
|
||||
elem /= shape[2];
|
||||
loc += (elem % shape[1]) * strides[1];
|
||||
elem /= shape[1];
|
||||
loc += (elem % shape[0]) * strides[0];
|
||||
METAL_FUNC uint3 elem_to_loc_3_nd(
|
||||
uint3 elem,
|
||||
constant const int shape[NDIM],
|
||||
constant const size_t a_strides[NDIM],
|
||||
constant const size_t b_strides[NDIM],
|
||||
constant const size_t c_strides[NDIM]) {
|
||||
uint3 loc = {
|
||||
static_cast<uint>(
|
||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
|
||||
static_cast<uint>(
|
||||
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
|
||||
for (int d = NDIM - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * a_strides[d];
|
||||
loc.y += l * b_strides[d];
|
||||
loc.z += l * c_strides[d];
|
||||
elem.z /= shape[d];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
|
@@ -206,7 +206,7 @@ inline auto collapse_batches(const array& a, const array& b) {
|
||||
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});
|
||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
@@ -237,8 +237,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
||||
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});
|
||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
|
@@ -865,7 +865,73 @@ void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_gpu_inplace(
|
||||
/* const array& in = */ in,
|
||||
/* array& out = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General,
|
||||
/* const Stream& s = */ stream());
|
||||
} else {
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, out);
|
||||
}
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if materialization is needed
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(out);
|
||||
|
||||
// Do copy
|
||||
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_gpu_inplace<int64_t>(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ upd_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
@@ -9,16 +9,43 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
int idx) {
|
||||
inline void
|
||||
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
enc->setBuffer(a_buf, offset, idx);
|
||||
}
|
||||
|
||||
inline void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
int64_t offset,
|
||||
int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const std::vector<T>& vec,
|
||||
size_t nelems,
|
||||
int idx) {
|
||||
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void set_vector_bytes(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const std::vector<T>& vec,
|
||||
int idx) {
|
||||
return set_vector_bytes(enc, vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
std::string type_to_name(const array& a) {
|
||||
std::string tname;
|
||||
switch (a.dtype()) {
|
||||
|
Reference in New Issue
Block a user