Mirror of https://github.com/ml-explore/mlx.git
Dynamic slicing (#1741)

* dynamic slice and slice update
* python bindings + tests + fix set item
* fix compile issue
* comment
* fix jit
@@ -52,7 +52,9 @@ void copy_gpu_inplace(
     int64_t inp_offset,
     int64_t out_offset,
     CopyType ctype,
-    const Stream& s) {
+    const Stream& s,
+    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
+    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
   if (out.size() == 0) {
     return;
   }
@@ -80,6 +82,7 @@ void copy_gpu_inplace(
   } else {
     large = out.data_size() > UINT32_MAX;
   }
+  bool dynamic = dynamic_i_offset || dynamic_o_offset;
   auto& d = metal::device(s.device);
   int work_per_thread = 1;
   std::string kernel_name;
@@ -107,9 +110,17 @@ void copy_gpu_inplace(
     if (large) {
       kernel_name += "large";
     }
+    if (dynamic) {
+      kernel_name += "_dynamic";
+      if (ctype != CopyType::GeneralGeneral) {
+        throw std::runtime_error(
+            "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
+      }
+    }
   }
   concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
-  auto kernel = get_copy_kernel(d, kernel_name, in, out);
+  auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
+                        : get_copy_kernel(d, kernel_name, in, out);

   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
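Note: the kernel name assembled in this hunk (in what appears to be the Metal copy dispatch, copy.cpp) concatenates the general-general prefix ("gg1".."gg3", or "ggn" plus a work-per-thread factor), an optional "large" suffix for arrays indexed beyond UINT32_MAX, the new "_dynamic" tag, and the input/output type names. A minimal sketch of that naming scheme, assuming a hypothetical helper copy_kernel_name (not an MLX function):

#include <iostream>
#include <string>

// Stand-in that mirrors the name composition above; the result must match
// one of the instantiated kernel names, e.g. "gg2_dynamic_copyfloat32float32".
std::string copy_kernel_name(int ndim, bool large, bool dynamic,
                             const std::string& in_t, const std::string& out_t) {
  std::string name = (ndim >= 1 && ndim <= 3)
      ? "gg" + std::to_string(ndim)
      : "ggn" + std::to_string(large ? 4 : 2); // n-d variant, work per thread
  if (large) name += "large";
  if (dynamic) name += "_dynamic";
  return name + "_copy" + in_t + out_t;
}

int main() {
  std::cout << copy_kernel_name(2, false, true, "float32", "float32") << "\n";
  // prints: gg2_dynamic_copyfloat32float32
}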
@@ -145,6 +156,18 @@ void copy_gpu_inplace(
     compute_encoder.set_bytes(ndim, 5);
     dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
   }
+  if (dynamic) {
+    if (dynamic_i_offset) {
+      compute_encoder.set_input_array(*dynamic_i_offset, 6);
+    } else {
+      compute_encoder.set_bytes(0ll, 6);
+    }
+    if (dynamic_o_offset) {
+      compute_encoder.set_input_array(*dynamic_o_offset, 7);
+    } else {
+      compute_encoder.set_bytes(0ll, 7);
+    }
+  }

   // NB assuming thread_group_size is a power of 2 larger than 32 x 32
   if (thread_group_size != 1024) {
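Note: this binding scheme keeps the kernel signature uniform. Buffer slots 6 and 7 always hold the input and output offsets; whichever side is static gets a zero int64 constant bound in place of an array. A small sketch of that rule, with FakeEncoder standing in for the Metal command encoder (assumed names throughout):

#include <cstdint>
#include <cstdio>
#include <optional>

// Stand-in for the command encoder; bind_offset mirrors the hunk above:
// bind an array when the offset is dynamic, a zero constant otherwise,
// always at the same buffer slot.
struct FakeEncoder {
  void set_input_array(int64_t offset, int slot) {
    std::printf("slot %d <- offset array (%lld)\n", slot,
                static_cast<long long>(offset));
  }
  void set_bytes(int64_t v, int slot) {
    std::printf("slot %d <- constant %lld\n", slot, static_cast<long long>(v));
  }
};

void bind_offset(FakeEncoder& enc, const std::optional<int64_t>& dyn, int slot) {
  if (dyn) {
    enc.set_input_array(*dyn, slot);
  } else {
    enc.set_bytes(0ll, slot);
  }
}

int main() {
  FakeEncoder enc;
  bind_offset(enc, int64_t{42}, 6); // dynamic input offset
  bind_offset(enc, std::nullopt, 7); // static output side
}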
@@ -179,13 +202,13 @@ void copy_gpu_inplace(
 void copy_gpu_inplace(
     const array& in,
     array& out,
-    const Strides& istride,
-    int64_t ioffset,
+    const Strides& i_strides,
+    int64_t i_offset,
     CopyType ctype,
     const Stream& s) {
   assert(in.shape() == out.shape());
   return copy_gpu_inplace(
-      in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
+      in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
 }

 void fill_gpu(const array& val, array& out, const Stream& s) {
@@ -17,13 +17,15 @@ void copy_gpu_inplace(
     int64_t i_offset,
     int64_t o_offset,
     CopyType ctype,
-    const Stream& s);
+    const Stream& s,
+    const std::optional<array>& dynamic_i_offset = std::nullopt,
+    const std::optional<array>& dynamic_o_offset = std::nullopt);

 void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
 void copy_gpu(const array& src, array& out, CopyType ctype);

 void copy_gpu_inplace(
-    const array& src,
+    const array& in,
     array& out,
     CopyType ctype,
     const Stream& s);
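Note: because the two new parameters default to std::nullopt, every existing call site of copy_gpu_inplace compiles unchanged; only callers that need a GPU-resident offset pass one. A minimal sketch of the pattern with a hypothetical function (not the MLX signature):

#include <cstdint>
#include <iostream>
#include <optional>

// The optional parameter defaults to std::nullopt, so legacy one-argument
// calls keep working while new calls can opt in to a dynamic offset.
void copy_like(int64_t static_offset,
               const std::optional<int64_t>& dynamic_offset = std::nullopt) {
  // The dynamic value, when present, overrides the static one.
  std::cout << "offset = " << dynamic_offset.value_or(static_offset) << "\n";
}

int main() {
  copy_like(4);              // pre-existing call site, unchanged
  copy_like(4, int64_t{16}); // new call site with a dynamic offset
}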
@@ -31,8 +33,8 @@ void copy_gpu_inplace(
 void copy_gpu_inplace(
     const array& in,
     array& out,
-    const Strides& istride,
-    int64_t ioffset,
+    const Strides& i_strides,
+    int64_t i_offset,
     CopyType ctype,
     const Stream& s);

@@ -218,6 +218,38 @@ MTL::ComputePipelineState* get_copy_kernel(
   return d.get_kernel(kernel_name, lib);
 }

+MTL::ComputePipelineState* get_dynamic_copy_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out) {
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
+  auto lib = d.get_library(lib_name, [&]() {
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::copy();
+    auto in_type = get_type_string(in.dtype());
+    auto out_type = get_type_string(out.dtype());
+    kernel_source += get_template_definition(
+        "gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int");
+    kernel_source += get_template_definition(
+        "ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int");
+    kernel_source += get_template_definition(
+        "gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type);
+    kernel_source += get_template_definition(
+        "gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type);
+    kernel_source += get_template_definition(
+        "gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type);
+    kernel_source += get_template_definition(
+        "ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4);
+    return kernel_source;
+  });
+  return d.get_kernel(kernel_name, lib);
+}
+
 MTL::ComputePipelineState* get_softmax_kernel(
     metal::Device& d,
     const std::string& kernel_name,
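Note: get_dynamic_copy_kernel follows the JIT pattern used elsewhere in this file: build the kernel source as a string, append one explicit template instantiation per host-visible kernel name, and compile the library on first use. The instantiations emitted by get_template_definition are roughly of the shape sketched below; the exact formatting is MLX-internal, and this stand-in only illustrates the idea:

#include <iostream>
#include <string>

// Stand-in for get_template_definition: instantiate a kernel template
// under a fixed host-visible name (assumed output shape, may differ from
// what MLX actually emits).
std::string template_definition(const std::string& name,
                                const std::string& func,
                                const std::string& args) {
  return "template [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      func + "<" + args + ">) " + func + "<" + args + ">;\n";
}

int main() {
  // Roughly what one of the eight instantiations above could look like
  // for float inputs and 32-bit indexing.
  std::cout << template_definition(
      "gg2_dynamic_copyfloat32float32", "copy_gg_dynamic_nd2",
      "float, float, int");
}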
@@ -45,6 +45,12 @@ MTL::ComputePipelineState* get_copy_kernel(
     const array& in,
     const array& out);

+MTL::ComputePipelineState* get_dynamic_copy_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out);
+
 MTL::ComputePipelineState* get_softmax_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -161,3 +161,78 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
     idx.y += dst_xstride;
   }
 }
+
+template <typename T, typename U, typename IdxT = int64_t>
+[[kernel]] void copy_gg_dynamic_nd1(
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t& src_stride [[buffer(3)]],
+    constant const int64_t& dst_stride [[buffer(4)]],
+    constant const int64_t& src_offset [[buffer(6)]],
+    constant const int64_t& dst_offset [[buffer(7)]],
+    uint index [[thread_position_in_grid]]) {
+  auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
+  auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
+  dst[dst_idx + dst_offset] = src[src_idx + src_offset];
+}
+
+template <typename T, typename U, typename IdxT = int64_t>
+[[kernel]] void copy_gg_dynamic_nd2(
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
+    constant const int64_t& src_offset [[buffer(6)]],
+    constant const int64_t& dst_offset [[buffer(7)]],
+    uint2 index [[thread_position_in_grid]]) {
+  auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
+  auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
+  dst[dst_idx + dst_offset] = src[src_idx + src_offset];
+}
+
+template <typename T, typename U, typename IdxT = int64_t>
+[[kernel]] void copy_gg_dynamic_nd3(
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
+    constant const int64_t& src_offset [[buffer(6)]],
+    constant const int64_t& dst_offset [[buffer(7)]],
+    uint3 index [[thread_position_in_grid]]) {
+  auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
+  auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
+  dst[dst_idx + dst_offset] = src[src_idx + src_offset];
+}
+
+template <typename T, typename U, int N = 1, typename IdxT = int64_t>
+[[kernel]] void copy_gg_dynamic(
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int* src_shape [[buffer(2)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
+    constant const int& ndim [[buffer(5)]],
+    constant const int64_t& src_offset [[buffer(6)]],
+    constant const int64_t& dst_offset [[buffer(7)]],
+    uint3 index [[thread_position_in_grid]]) {
+  src += src_offset;
+  dst += dst_offset;
+  auto idx = elem_to_loc_2_nd<IdxT>(
+      {N * index.x, index.y, index.z},
+      src_shape,
+      src_strides,
+      dst_strides,
+      ndim);
+  if (N == 1) {
+    dst[idx.y] = src[idx.x];
+    return;
+  }
+  IdxT src_xstride = src_strides[ndim - 1];
+  IdxT dst_xstride = dst_strides[ndim - 1];
+  auto xshape = src_shape[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    dst[idx.y] = src[idx.x];
+    idx.x += src_xstride;
+    idx.y += dst_xstride;
+  }
+}
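Note: each of these kernels maps its thread position to a linear element location through the strides and only then adds the dynamic offset, so the offset never has to be folded into the strides on the host. A CPU sketch of the nd2 index math, assuming elem_to_loc_2 follows the stride convention used above (innermost axis on x):

#include <cstdint>
#include <cstdio>

// CPU sketch of elem_to_loc_2 as used by copy_gg_dynamic_nd2: x indexes the
// innermost axis (strides[1]), y the outer axis (strides[0]).
int64_t elem_to_loc_2(uint32_t x, uint32_t y, const int64_t strides[2]) {
  return int64_t(x) * strides[1] + int64_t(y) * strides[0];
}

int main() {
  const int64_t src_strides[2] = {8, 1}; // contiguous 4x8, row-major
  const int64_t src_offset = 16;         // dynamic start, e.g. row 2
  // Thread (3, 1) reads the element at row 1, column 3 of the slice window.
  std::printf("%lld\n",
              static_cast<long long>(
                  elem_to_loc_2(3, 1, src_strides) + src_offset)); // 27
}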
@@ -4,29 +4,37 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/copy.h"

-#define instantiate_copy_all(tname, itype, otype) \
-  instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
-  instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
-  instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
-  instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
-  instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
-  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
-  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
-  instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
-  instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
-  instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
-  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype)
+#define instantiate_copy_all(tname, itype, otype) \
+  instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
+  instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
+  instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
+  instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
+  instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
+  instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
+  instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
+  instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
+  instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
+  instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
+  instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
+  instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)

-#define instantiate_copy_same(tname, type) \
-  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
-  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
-  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
-  instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
-  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
-  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
-  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
-  instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4) \
-  instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4)
+#define instantiate_copy_same(tname, type) \
+  instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
+  instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
+  instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
+  instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
+  instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
+  instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
+  instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
+  instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) \
+  instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int) \
+  instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int) \
+  instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int) \
+  instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int) \
+  instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type) \
+  instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type) \
+  instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type) \
+  instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4)

 #define instantiate_copy_itype(itname, itype) \
   instantiate_copy_same(itname ##itname, itype) \
@@ -56,6 +56,14 @@ MTL::ComputePipelineState* get_copy_kernel(
   return d.get_kernel(kernel_name);
 }

+MTL::ComputePipelineState* get_dynamic_copy_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array&,
+    const array&) {
+  return d.get_kernel(kernel_name);
+}
+
 MTL::ComputePipelineState* get_softmax_kernel(
     metal::Device& d,
     const std::string& kernel_name,
@@ -4,6 +4,7 @@
 #include <numeric>
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
+#include "mlx/backend/common/load.h"
 #include "mlx/backend/common/slicing.h"
 #include "mlx/backend/common/utils.h"
@@ -43,6 +44,59 @@ void reshape(const array& in, array& out, Stream s) {
     shared_buffer_reshape(in, out_strides, out);
   }
 }
+array compute_dynamic_offset(
+    const array& indices,
+    const Strides& strides,
+    const std::vector<int>& axes,
+    Stream s) {
+  auto& d = metal::device(s.device);
+
+  // Kernel to compute offset here.
+  array offset({1}, int64, nullptr, {});
+  bool donate = indices.is_donatable() &&
+      (indices.data_size() * indices.itemsize()) >= offset.itemsize();
+  if (donate) {
+    offset.move_shared_buffer(indices);
+  } else {
+    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
+  }
+  d.add_temporary(offset, s.index);
+
+  auto dtype = indices.dtype();
+  std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
+  auto lib = d.get_library(lib_name, [dtype]() {
+    return fmt::format(
+        R"(
+  [[kernel]] void compute_dynamic_offset_{0}(
+      constant const {1}* indices [[buffer(0)]],
+      device int64_t& offset [[buffer(1)]],
+      constant const int64_t* strides [[buffer(2)]],
+      constant const int* axes [[buffer(3)]],
+      constant const int& n_axes [[buffer(4)]],
+      uint index [[thread_position_in_grid]]) {{
+    int64_t acc = 0;
+    for (int i = 0; i < n_axes; ++i) {{
+      acc += indices[i] * strides[axes[i]];
+    }}
+    offset = acc;
+  }})",
+        type_to_name(dtype),
+        get_type_string(dtype));
+  });
+  auto kernel = d.get_kernel(lib_name, lib);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(donate ? offset : indices, 0);
+  compute_encoder.set_output_array(offset, 1);
+  compute_encoder.set_vector_bytes(strides, 2);
+  compute_encoder.set_vector_bytes(axes, 3);
+  int n_axes = axes.size();
+  compute_encoder.set_bytes(n_axes, 4);
+  MTL::Size dims = MTL::Size(1, 1, 1);
+  compute_encoder.dispatch_threads(dims, dims);
+  return offset;
+}

 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
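Note: the generated kernel runs as a single GPU thread and reduces the start indices against the strides of the sliced axes, so the offset stays on the device and no host synchronization is needed. The arithmetic is offset = sum_i indices[i] * strides[axes[i]]; a CPU reference of the same computation:

#include <cstdint>
#include <cstdio>
#include <vector>

// CPU reference for the one-thread Metal kernel generated above.
int64_t dynamic_offset(const std::vector<int32_t>& indices,
                       const std::vector<int64_t>& strides,
                       const std::vector<int>& axes) {
  int64_t acc = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    acc += static_cast<int64_t>(indices[i]) * strides[axes[i]];
  }
  return acc;
}

int main() {
  // Start (2, 3) in a row-major 4x8 array: 2 * 8 + 3 * 1 = 19 elements.
  std::printf("%lld\n",
              static_cast<long long>(dynamic_offset({2, 3}, {8, 1}, {0, 1})));
}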
@@ -356,6 +410,72 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
   slice_gpu(in, out, start_indices_, strides_, stream());
 }

+void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+  auto& start = inputs[1];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  auto s = stream();
+  auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
+  copy_gpu_inplace(
+      /* const array& src = */ in,
+      /* array& dst = */ out,
+      /* const Shape& data_shape = */ out.shape(),
+      /* const Strides& i_strides = */ in.strides(),
+      /* const Strides& o_strides = */ out.strides(),
+      /* int64_t i_offset = */ 0,
+      /* int64_t o_offset = */ 0,
+      /* CopyType ctype = */ CopyType::GeneralGeneral,
+      /* const Stream& s = */ s,
+      /* const std::optional<array>& dynamic_i_offset = */ in_offset,
+      /* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
+}
+
+void DynamicSliceUpdate::eval_gpu(
+    const std::vector<array>& inputs,
+    array& out) {
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+  auto& upd = inputs[1];
+  auto& start_indices = inputs[2];
+
+  if (upd.size() == 0) {
+    move_or_copy(in, out);
+    return;
+  }
+
+  // Copy or donate input to output
+  auto s = stream();
+  auto& d = metal::device(s.device);
+  auto ctype = in.flags().contiguous && in.size() == in.data_size()
+      ? CopyType::Vector
+      : CopyType::General;
+  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
+
+  auto out_offset =
+      compute_dynamic_offset(start_indices, out.strides(), axes_, s);
+  copy_gpu_inplace(
+      /* const array& src = */ upd,
+      /* array& dst = */ out,
+      /* const Shape& data_shape = */ upd.shape(),
+      /* const Strides& i_strides = */ upd.strides(),
+      /* const Strides& o_strides = */ out.strides(),
+      /* int64_t i_offset = */ 0,
+      /* int64_t o_offset = */ 0,
+      /* CopyType ctype = */ CopyType::GeneralGeneral,
+      /* const Stream& s = */ s,
+      /* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
+      /* const std::optional<array>& dynamic_o_offset = */ out_offset);
+}
+
 void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   if (out.size() == 0) {
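Note: DynamicSliceUpdate is two copies back to back: materialize the input into the output (donating the buffer when possible), then write the update window at the offset computed on the GPU. A CPU reference of those semantics on a contiguous 1-d array (an assumed simplification of the general strided case):

#include <cstdint>
#include <cstdio>
#include <vector>

// CPU reference for DynamicSliceUpdate on a contiguous 1-d array: copy the
// input through to the output, then overwrite the window starting at the
// dynamically computed offset with the update.
std::vector<float> dynamic_slice_update(const std::vector<float>& in,
                                        const std::vector<float>& upd,
                                        int64_t out_offset) {
  std::vector<float> out = in; // step 1: copy or donate input to output
  for (size_t i = 0; i < upd.size(); ++i) {
    out[out_offset + i] = upd[i]; // step 2: offset general-general copy
  }
  return out;
}

int main() {
  auto out = dynamic_slice_update({0, 0, 0, 0, 0, 0}, {1, 2}, 3);
  for (float v : out) std::printf("%g ", v); // 0 0 0 1 2 0
  std::printf("\n");
}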
@@ -371,13 +491,11 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  // Check if materialization is needed
   auto ctype = in.flags().contiguous && in.size() == in.data_size()
       ? CopyType::Vector
       : CopyType::General;
   copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

-  // Calculate out strides, initial offset and if copy needs to be made
   auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);

   // Do copy