Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun 2025-01-07 14:02:16 -08:00 committed by GitHub
parent c9c81d0584
commit 516ded618b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 941 additions and 75 deletions

View File

@ -145,6 +145,8 @@ Operations
sign sign
sin sin
sinh sinh
slice
slice_update
softmax softmax
sort sort
split split

View File

@ -35,29 +35,29 @@ class array {
explicit array(const std::complex<float>& val, Dtype dtype = complex64); explicit array(const std::complex<float>& val, Dtype dtype = complex64);
template <typename It> template <typename It>
array( explicit array(
It data, It data,
Shape shape, Shape shape,
Dtype dtype = Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>()); TypeToDtype<typename std::iterator_traits<It>::value_type>());
template <typename T> template <typename T>
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>()); explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
/* Special case so empty lists default to float32. */ /* Special case so empty lists default to float32. */
array(std::initializer_list<float> data); explicit array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */ /* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype); explicit array(std::initializer_list<int> data, Dtype dtype);
template <typename T> template <typename T>
array( explicit array(
std::initializer_list<T> data, std::initializer_list<T> data,
Shape shape, Shape shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */ /* Build an array from a buffer */
array( explicit array(
allocator::Buffer data, allocator::Buffer data,
Shape shape, Shape shape,
Dtype dtype, Dtype dtype,

View File

@ -29,6 +29,35 @@ void reshape(const array& in, array& out) {
} }
} }
int64_t compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
default:
throw std::runtime_error("Invalid indices type.");
}
}
void Abs::eval(const std::vector<array>& inputs, array& out) { void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
@ -519,6 +548,54 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
shared_buffer_slice(in, ostrides, data_offset, data_size, out); shared_buffer_slice(in, ostrides, data_offset, data_size, out);
} }
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
// Copy or move src to dst
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (out.size() == 0) { if (out.size() == 0) {
@ -544,12 +621,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy // Do copy
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(), /* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides, /* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out_strides, /* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0, /* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset, /* int64_t o_offset = */ data_offset,

View File

@ -52,7 +52,9 @@ void copy_gpu_inplace(
int64_t inp_offset, int64_t inp_offset,
int64_t out_offset, int64_t out_offset,
CopyType ctype, CopyType ctype,
const Stream& s) { const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
@ -80,6 +82,7 @@ void copy_gpu_inplace(
} else { } else {
large = out.data_size() > UINT32_MAX; large = out.data_size() > UINT32_MAX;
} }
bool dynamic = dynamic_i_offset || dynamic_o_offset;
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int work_per_thread = 1; int work_per_thread = 1;
std::string kernel_name; std::string kernel_name;
@ -107,9 +110,17 @@ void copy_gpu_inplace(
if (large) { if (large) {
kernel_name += "large"; kernel_name += "large";
} }
if (dynamic) {
kernel_name += "_dynamic";
if (ctype != CopyType::GeneralGeneral) {
throw std::runtime_error(
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
}
}
} }
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
: get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@ -145,6 +156,18 @@ void copy_gpu_inplace(
compute_encoder.set_bytes(ndim, 5); compute_encoder.set_bytes(ndim, 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} }
if (dynamic) {
if (dynamic_i_offset) {
compute_encoder.set_input_array(*dynamic_i_offset, 6);
} else {
compute_encoder.set_bytes(0ll, 6);
}
if (dynamic_o_offset) {
compute_encoder.set_input_array(*dynamic_o_offset, 7);
} else {
compute_encoder.set_bytes(0ll, 7);
}
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32 // NB assuming thread_group_size is a power of 2 larger than 32 x 32
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
@ -179,13 +202,13 @@ void copy_gpu_inplace(
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
const Strides& istride, const Strides& i_strides,
int64_t ioffset, int64_t i_offset,
CopyType ctype, CopyType ctype,
const Stream& s) { const Stream& s) {
assert(in.shape() == out.shape()); assert(in.shape() == out.shape());
return copy_gpu_inplace( return copy_gpu_inplace(
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s); in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
} }
void fill_gpu(const array& val, array& out, const Stream& s) { void fill_gpu(const array& val, array& out, const Stream& s) {

View File

@ -17,13 +17,15 @@ void copy_gpu_inplace(
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype, CopyType ctype,
const Stream& s); const Stream& s,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype); void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace( void copy_gpu_inplace(
const array& src, const array& in,
array& out, array& out,
CopyType ctype, CopyType ctype,
const Stream& s); const Stream& s);
@ -31,8 +33,8 @@ void copy_gpu_inplace(
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
const Strides& istride, const Strides& i_strides,
int64_t ioffset, int64_t i_offset,
CopyType ctype, CopyType ctype,
const Stream& s); const Stream& s);

View File

@ -218,6 +218,38 @@ MTL::ComputePipelineState* get_copy_kernel(
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
MTL::ComputePipelineState* get_dynamic_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_softmax_kernel( MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,

View File

@ -45,6 +45,12 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in, const array& in,
const array& out); const array& out);
MTL::ComputePipelineState* get_dynamic_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_softmax_kernel( MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,

View File

@ -161,3 +161,78 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
idx.y += dst_xstride; idx.y += dst_xstride;
} }
} }
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
}
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
constant const int64_t& src_offset [[buffer(6)]],
constant const int64_t& dst_offset [[buffer(7)]],
uint3 index [[thread_position_in_grid]]) {
src += src_offset;
dst += dst_offset;
auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
dst_strides,
ndim);
if (N == 1) {
dst[idx.y] = src[idx.x];
return;
}
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = src[idx.x];
idx.x += src_xstride;
idx.y += dst_xstride;
}
}

View File

@ -15,7 +15,8 @@
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \ instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \ instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
#define instantiate_copy_same(tname, type) \ #define instantiate_copy_same(tname, type) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
@ -25,8 +26,15 @@
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \ instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \ instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \ instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4) \ instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int) \
instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int) \
instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int) \
instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int) \
instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type) \
instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type) \
instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type) \
instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4)
#define instantiate_copy_itype(itname, itype) \ #define instantiate_copy_itype(itname, itype) \
instantiate_copy_same(itname ##itname, itype) \ instantiate_copy_same(itname ##itname, itype) \

View File

@ -56,6 +56,14 @@ MTL::ComputePipelineState* get_copy_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_dynamic_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_softmax_kernel( MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,

View File

@ -4,6 +4,7 @@
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/load.h" #include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
@ -43,6 +44,59 @@ void reshape(const array& in, array& out, Stream s) {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
} }
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
Stream s) {
auto& d = metal::device(s.device);
// Kernel to compute offset here.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.move_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
}
d.add_temporary(offset, s.index);
auto dtype = indices.dtype();
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
auto lib = d.get_library(lib_name, [dtype]() {
return fmt::format(
R"(
[[kernel]] void compute_dynamic_offset_{0}(
constant const {1}* indices [[buffer(0)]],
device int64_t& offset [[buffer(1)]],
constant const int64_t* strides [[buffer(2)]],
constant const int* axes [[buffer(3)]],
constant const int& n_axes [[buffer(4)]],
uint index [[thread_position_in_grid]]) {{
int64_t acc = 0;
for (int i = 0; i < n_axes; ++i) {{
acc += indices[i] * strides[axes[i]];
}}
offset = acc;
}})",
type_to_name(dtype),
get_type_string(dtype));
});
auto kernel = d.get_kernel(lib_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donate ? offset : indices, 0);
compute_encoder.set_output_array(offset, 1);
compute_encoder.set_vector_bytes(strides, 2);
compute_encoder.set_vector_bytes(axes, 3);
int n_axes = axes.size();
compute_encoder.set_bytes(n_axes, 4);
MTL::Size dims = MTL::Size(1, 1, 1);
compute_encoder.dispatch_threads(dims, dims);
return offset;
}
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) { void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0); assert(inputs.size() == 0);
@ -356,6 +410,72 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
slice_gpu(in, out, start_indices_, strides_, stream()); slice_gpu(in, out, start_indices_, strides_, stream());
} }
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
}
void DynamicSliceUpdate::eval_gpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
auto& start_indices = inputs[2];
if (upd.size() == 0) {
move_or_copy(in, out);
return;
}
// Copy or donate input to output
auto s = stream();
auto& d = metal::device(s.device);
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
auto out_offset =
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
}
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (out.size() == 0) { if (out.size() == 0) {
@ -371,13 +491,11 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size() auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector ? CopyType::Vector
: CopyType::General; : CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy // Do copy

View File

@ -48,6 +48,8 @@ NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends) NO_CPU_MULTI(Depends)
NO_CPU(Divide) NO_CPU(Divide)
NO_CPU_MULTI(DivMod) NO_CPU_MULTI(DivMod)
NO_CPU(DynamicSlice)
NO_CPU(DynamicSliceUpdate)
NO_CPU(NumberOfElements) NO_CPU(NumberOfElements)
NO_CPU(Remainder) NO_CPU(Remainder)
NO_CPU_MULTI(Eigh) NO_CPU_MULTI(Eigh)

View File

@ -49,6 +49,8 @@ NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends) NO_GPU_MULTI(Depends)
NO_GPU(Divide) NO_GPU(Divide)
NO_GPU_MULTI(DivMod) NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(NumberOfElements) NO_GPU(NumberOfElements)
NO_GPU(Remainder) NO_GPU(Remainder)
NO_GPU(Equal) NO_GPU(Equal)

View File

@ -253,6 +253,8 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(Depends), SERIALIZE_PRIMITIVE(Depends),
SERIALIZE_PRIMITIVE(Divide), SERIALIZE_PRIMITIVE(Divide),
SERIALIZE_PRIMITIVE(DivMod), SERIALIZE_PRIMITIVE(DivMod),
SERIALIZE_PRIMITIVE(DynamicSlice),
SERIALIZE_PRIMITIVE(DynamicSliceUpdate),
SERIALIZE_PRIMITIVE(Equal, "NaNEqual"), SERIALIZE_PRIMITIVE(Equal, "NaNEqual"),
SERIALIZE_PRIMITIVE(Erf), SERIALIZE_PRIMITIVE(Erf),
SERIALIZE_PRIMITIVE(ErfInv), SERIALIZE_PRIMITIVE(ErfInv),

View File

@ -647,6 +647,52 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
return std::make_pair(has_neg_strides, out_shape); return std::make_pair(has_neg_strides, out_shape);
} }
void normalize_dynamic_slice_inputs(
const array& a,
const array& start,
std::vector<int>& axes,
const std::string prefix) {
if (start.size() > a.ndim()) {
std::ostringstream msg;
msg << prefix << " Invalid number of starting positions for "
<< "array with dimension " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if (start.ndim() > 1) {
std::ostringstream msg;
msg << prefix << " array of starting indices "
<< "must be zero or one dimensional but has dimension " << start.ndim()
<< ".";
throw std::invalid_argument(msg.str());
}
if (start.size() != axes.size()) {
std::ostringstream msg;
msg << prefix << " Number of starting indices " << start.size()
<< " does not match number of axes " << axes.size() << ".";
throw std::invalid_argument(msg.str());
}
if (!issubdtype(start.dtype(), integer)) {
std::ostringstream msg;
msg << prefix << " Start indices must be integers, got type "
<< start.dtype() << ".";
throw std::invalid_argument(msg.str());
}
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
if (new_ax < 0 || new_ax >= a.ndim()) {
std::ostringstream msg;
msg << prefix << " Invalid axis " << ax << " for array with dimension "
<< a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
std::set dims(axes.begin(), axes.end());
if (dims.size() != axes.size()) {
throw std::invalid_argument(prefix + " Repeat axes not allowed.");
}
}
} // namespace } // namespace
array slice( array slice(
@ -687,6 +733,38 @@ array slice(
a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s)); a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
} }
array slice(
const array& a,
const array& start,
std::vector<int> axes,
Shape slice_size,
StreamOrDevice s /* = {} */) {
normalize_dynamic_slice_inputs(a, start, axes, "[slice]");
// Check the slice_size
if (slice_size.size() != a.ndim()) {
std::ostringstream msg;
msg << "[slice] Invalid slice size for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < a.ndim(); ++i) {
if (slice_size[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[slice] Invalid slice size " << slice_size
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
auto out_shape = slice_size;
return array(
std::move(out_shape),
a.dtype(),
std::make_shared<DynamicSlice>(
to_stream(s), std::move(axes), std::move(slice_size)),
{a, start});
}
/** Update a slice from the source array */ /** Update a slice from the source array */
array slice_update( array slice_update(
const array& src, const array& src,
@ -699,7 +777,7 @@ array slice_update(
if (start.size() != src.ndim() || stop.size() != src.ndim() || if (start.size() != src.ndim() || stop.size() != src.ndim() ||
strides.size() != src.ndim()) { strides.size() != src.ndim()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[slice] Invalid number of indices or strides for " msg << "[slice_update] Invalid number of indices or strides for "
<< "array with dimension " << src.ndim() << "."; << "array with dimension " << src.ndim() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
@ -734,6 +812,36 @@ array slice_update(
src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
} }
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
const array& start,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
normalize_dynamic_slice_inputs(src, start, axes, "[slice_update]");
// Broadcast update with unspecified axes
auto up_shape = update.shape();
auto dim_diff = std::max(src.ndim() - update.ndim(), 0ul);
up_shape.insert(
up_shape.begin(), src.shape().begin(), src.shape().begin() + dim_diff);
for (int d = dim_diff; d < src.ndim(); ++d) {
up_shape[d] = std::min(up_shape[d], src.shape(d));
}
for (auto ax : axes) {
if (ax < dim_diff) {
up_shape[ax] = 1;
}
}
auto upd = broadcast_to(astype(update, src.dtype(), s), up_shape, s);
return array(
src.shape(),
src.dtype(),
std::make_shared<DynamicSliceUpdate>(to_stream(s), std::move(axes)),
{src, upd, start});
}
std::vector<array> split( std::vector<array> split(
const array& a, const array& a,
const Shape& indices, const Shape& indices,

View File

@ -164,11 +164,27 @@ array slice(
Shape stop, Shape stop,
Shape strides, Shape strides,
StreamOrDevice s = {}); StreamOrDevice s = {});
inline array slice(
const array& a,
std::initializer_list<int> start,
Shape stop,
Shape strides,
StreamOrDevice s = {}) {
return slice(a, Shape(start), std::move(stop), std::move(strides), s);
}
/** Slice an array with a stride of 1 in each dimension. */ /** Slice an array with a stride of 1 in each dimension. */
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {}); array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
/** Update a slice from the source array */ /** Slice an array with dynamic starting indices. */
array slice(
const array& a,
const array& start,
std::vector<int> axes,
Shape slice_size,
StreamOrDevice s = {});
/** Update a slice from the source array. */
array slice_update( array slice_update(
const array& src, const array& src,
const array& update, const array& update,
@ -177,7 +193,7 @@ array slice_update(
Shape strides, Shape strides,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */ /** Update a slice from the source array with stride 1 in each dimension. */
array slice_update( array slice_update(
const array& src, const array& src,
const array& update, const array& update,
@ -185,6 +201,14 @@ array slice_update(
Shape stop, Shape stop,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Update a slice from the source array with dynamic starting indices. */
array slice_update(
const array& src,
const array& update,
const array& start,
std::vector<int> axes,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */ /** Split an array into sub-arrays along a given axis. */
std::vector<array> std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); split(const array& a, int num_splits, int axis, StreamOrDevice s = {});

View File

@ -3715,22 +3715,18 @@ std::vector<array> SliceUpdate::vjp(
for (int num : argnums) { for (int num : argnums) {
// Vjp for source // Vjp for source
if (num == 0) { if (num == 0) {
auto grad = slice_update( vjps.push_back(slice_update(
cotan, cotan,
zeros_like(upd, stream()), zeros_like(upd, stream()),
start_indices_, start_indices_,
end_indices_, end_indices_,
strides_, strides_,
stream()); stream()));
vjps.push_back(grad);
} }
// Vjp fpr updates // Vjp fpr updates
else { else {
auto grad = vjps.push_back(
slice(cotan, start_indices_, end_indices_, strides_, stream()); slice(cotan, start_indices_, end_indices_, strides_, stream()));
vjps.push_back(grad);
} }
} }
@ -3753,12 +3749,153 @@ std::vector<array> SliceUpdate::jvp(
} }
bool SliceUpdate::is_equivalent(const Primitive& other) const { bool SliceUpdate::is_equivalent(const Primitive& other) const {
const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other); const auto& s_other = static_cast<const SliceUpdate&>(other);
return ( return (
start_indices_ == s_other.start_indices_ && start_indices_ == s_other.start_indices_ &&
end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_); end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
} }
std::pair<std::vector<array>, std::vector<int>> DynamicSlice::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto& in = inputs[0];
auto& start = inputs[1];
auto vax = axes[0];
if (axes[1] >= 0) {
throw std::invalid_argument(
"[DynamicSlice::vmap] vmap over start indices not yet supported.");
}
auto slice_size = slice_size_;
auto slice_axes = axes_;
if (vax >= 0) {
for (auto& ax : slice_axes) {
if (ax >= vax) {
ax++;
}
}
slice_size.insert(slice_size.begin() + vax, in.shape(vax));
}
return {
{slice(
in, start, std::move(slice_axes), std::move(slice_size), stream())},
{vax}};
}
std::vector<array> DynamicSlice::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
if (argnums[0] == 1 || argnums.size() > 1) {
throw std::invalid_argument(
"[DynamicSlice::vjp] Not supported for start indices.");
}
auto out = zeros_like(primals[0], stream());
return {slice_update(out, cotangents[0], primals[1], axes_, stream())};
}
std::vector<array> DynamicSlice::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {slice(tangents[0], primals[1], axes_, slice_size_, stream())};
}
bool DynamicSlice::is_equivalent(const Primitive& other) const {
const auto& s_other = static_cast<const DynamicSlice&>(other);
return (axes_ == s_other.axes_ && slice_size_ == s_other.slice_size_);
}
std::vector<Shape> DynamicSlice::output_shapes(const std::vector<array>&) {
return {slice_size_};
}
std::pair<std::vector<array>, std::vector<int>> DynamicSliceUpdate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto src = inputs[0];
auto upd = inputs[1];
auto& start = inputs[2];
auto src_ax = axes[0];
auto upd_ax = axes[1];
if (axes[2] >= 0) {
throw std::runtime_error(
"[DynamicSliceUpdate::vmap] vmap over start indices not yet supported.");
}
// No vmapping needed
if (src_ax == -1 && upd_ax == -1) {
return {{slice_update(src, upd, start, axes_, stream())}, {-1}};
}
// Broadcast src
if (src_ax == -1) {
src = expand_dims(src, upd_ax, stream());
auto shape = src.shape();
shape[upd_ax] = upd.shape(upd_ax);
src = broadcast_to(src, shape, stream());
src_ax = upd_ax;
}
// Broadcast upd
if (upd_ax == -1) {
upd = expand_dims(upd, src_ax, stream());
upd_ax = src_ax;
}
if (src_ax != upd_ax) {
upd = moveaxis(upd, upd_ax, src_ax, stream());
}
auto slice_axes = axes_;
for (auto& ax : slice_axes) {
if (ax >= src_ax) {
ax++;
}
}
return {
{slice_update(src, upd, start, std::move(slice_axes), stream())},
{src_ax}};
}
std::vector<array> DynamicSliceUpdate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& cotan = cotangents[0];
auto& upd = primals[1];
auto& start = primals[2];
std::vector<array> vjps;
for (int num : argnums) {
if (num == 0) {
// Vjp for source
vjps.push_back(slice_update(
cotan, zeros_like(upd, stream()), start, axes_, stream()));
} else if (num == 1) {
// Vjp fpr updates
vjps.push_back(slice(cotan, start, axes_, upd.shape(), stream()));
} else {
throw std::invalid_argument(
"[DynamicSliceUpdate::vjp] Not supported for start indices");
}
}
return vjps;
}
std::vector<array> DynamicSliceUpdate::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {slice_update(tangents[0], tangents[1], primals[2], axes_, stream())};
}
bool DynamicSliceUpdate::is_equivalent(const Primitive& other) const {
const auto& s_other = static_cast<const DynamicSliceUpdate&>(other);
return axes_ == s_other.axes_;
}
std::pair<std::vector<array>, std::vector<int>> Softmax::vmap( std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {

View File

@ -2057,6 +2057,51 @@ class SliceUpdate : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class DynamicSlice : public UnaryPrimitive {
public:
explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
slice_size_(std::move(slice_size)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(DynamicSlice)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_pair(axes_, slice_size_);
}
private:
std::vector<int> axes_;
Shape slice_size_;
};
class DynamicSliceUpdate : public UnaryPrimitive {
public:
explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(DynamicSliceUpdate)
bool is_equivalent(const Primitive& other) const override;
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return axes_;
}
private:
std::vector<int> axes_;
};
class Softmax : public UnaryPrimitive { class Softmax : public UnaryPrimitive {
public: public:
explicit Softmax(Stream stream, bool precise) explicit Softmax(Stream stream, bool precise)

View File

@ -764,7 +764,7 @@ auto mlx_slice_update(
const mx::array& src, const mx::array& src,
const nb::object& obj, const nb::object& obj,
const ScalarOrArray& v) { const ScalarOrArray& v) {
// Can't route to slice update if not slice or tuple // Can't route to slice update if not slice, tuple, or int
if (src.ndim() == 0 || if (src.ndim() == 0 ||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) && (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
!nb::isinstance<nb::int_>(obj))) { !nb::isinstance<nb::int_>(obj))) {
@ -845,20 +845,14 @@ auto mlx_slice_update(
return std::make_pair(true, broadcast_to(up, src.shape())); return std::make_pair(true, broadcast_to(up, src.shape()));
} }
// Process entries int unspecified = src.ndim() - non_none_indices;
mx::Shape up_reshape(src.ndim()); std::vector<int> squeeze_dims;
int ax = src.ndim() - 1; std::vector<int> expand_dims;
int up_ax = up.ndim() - 1; for (int i = indices.size() - 1,
for (; ax >= non_none_indices; ax--) { ax = non_none_indices - 1,
if (up_ax >= 0) { upd_ax = upd.ndim() - unspecified - 1;
up_reshape[ax] = up.shape(up_ax); i >= 0;
up_ax--; --i) {
} else {
up_reshape[ax] = 1;
}
}
for (int i = indices.size() - 1; i >= 0; --i) {
auto& pyidx = indices[i]; auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) { if (nb::isinstance<nb::slice>(pyidx)) {
get_slice_params( get_slice_params(
@ -867,19 +861,26 @@ auto mlx_slice_update(
strides[ax], strides[ax],
nb::cast<nb::slice>(pyidx), nb::cast<nb::slice>(pyidx),
src.shape(ax)); src.shape(ax));
up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1;
ax--; ax--;
upd_ax--;
} else if (nb::isinstance<nb::int_>(pyidx)) { } else if (nb::isinstance<nb::int_>(pyidx)) {
int st = nb::cast<int>(pyidx); int st = nb::cast<int>(pyidx);
st = (st < 0) ? st + src.shape(ax) : st; st = (st < 0) ? st + src.shape(i) : st;
starts[ax] = st; starts[ax] = st;
stops[ax] = st + 1; stops[ax] = st + 1;
up_reshape[ax] = 1; if (upd_ax >= 0) {
expand_dims.push_back(i - indices.size() - unspecified);
}
ax--; ax--;
} else if (pyidx.is_none()) {
if (upd_ax-- >= 0) {
squeeze_dims.push_back(i - indices.size() - unspecified);
}
} }
} }
up = reshape(up, std::move(up_reshape)); up = mx::squeeze(
mx::expand_dims(up, std::move(expand_dims)), std::move(squeeze_dims));
auto out = slice_update(src, up, starts, stops, strides); auto out = slice_update(src, up, starts, stops, strides);
return std::make_pair(true, out); return std::make_pair(true, out);
} }

View File

@ -5004,4 +5004,81 @@ void init_ops(nb::module_& m) {
Returns: Returns:
array: The imaginary part of ``a``. array: The imaginary part of ``a``.
)pbdoc"); )pbdoc");
m.def(
"slice",
[](const mx::array& a,
const mx::array& start_indices,
std::vector<int> axes,
mx::Shape slice_size,
mx::StreamOrDevice s) {
return mx::slice(
a, start_indices, std::move(axes), std::move(slice_size), s);
},
nb::arg(),
"start_indices"_a,
"axes"_a,
"slice_size"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Extract a sub-array from the input array.
Args:
a (array): Input array
start_indices (array): The index location to start the slice at.
axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.
slice_size (tuple(int)): The size of the slice.
Returns:
array: The sliced output array.
Example:
>>> a = mx.array([[1, 2, 3], [4, 5, 6]])
>>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))
array([[4, 5]], dtype=int32)
>>>
>>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))
array([[2],
[5]], dtype=int32)
)pbdoc");
m.def(
"slice_update",
[](const mx::array& src,
const mx::array& update,
const mx::array& start_indices,
std::vector<int> axes,
mx::StreamOrDevice s) {
return mx::slice_update(src, update, start_indices, axes, s);
},
nb::arg(),
"update"_a,
"start_indices"_a,
"axes"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Update a sub-array of the input array.
Args:
a (array): The input array to update
update (array): The update array.
start_indices (array): The index location to start the slice at.
axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.
Returns:
array: The output array with the same shape and type as the input.
Example:
>>> a = mx.zeros((3, 3))
>>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1))
array([[0, 0, 0],
[0, 1, 0],
[0, 1, 0]], dtype=float32)
)pbdoc");
} }

View File

@ -1327,6 +1327,16 @@ class TestArray(mlx_tests.MLXTestCase):
x[0, 0] = 1 x[0, 0] = 1
self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2)))) self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2))))
a = mx.zeros((2, 2, 2))
with self.assertRaises(ValueError):
a[:, None, :] = mx.ones((2, 2, 2))
# Ok, doesn't throw
a[:, None, :] = mx.ones((2, 1, 2, 2))
a[:, None, :] = mx.ones((2, 2))
a[:, None, 0] = mx.ones((2,))
a[:, None, 0] = mx.ones((1, 2))
def test_array_at(self): def test_array_at(self):
a = mx.array(1) a = mx.array(1)
a = a.at[None].add(1) a = a.at[None].add(1)

View File

@ -2769,6 +2769,19 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.imag(z).dtype, mx.float32) self.assertEqual(mx.imag(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.imag(z), y)) self.assertTrue(mx.array_equal(mx.imag(z), y))
def test_dynamic_slicing(self):
x = mx.random.randint(0, 100, shape=(4, 4, 4))
expected = x[1:, 2:, 3:]
out = mx.slice(x, mx.array([1, 2, 3]), (0, 1, 2), (3, 2, 1))
self.assertTrue(mx.array_equal(expected, out))
x = mx.zeros(shape=(4, 4, 4))
update = mx.random.randint(0, 100, shape=(3, 2, 1))
out = mx.slice_update(x, update, mx.array([1, 2, 3]), (0, 1, 2))
expected = mx.zeros_like(x)
expected[1:, 2:, 3:] = update
self.assertTrue(mx.array_equal(expected, out))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1291,3 +1291,24 @@ TEST_CASE("test grad types") {
} }
} }
} }
TEST_CASE("test grad dynamic slices") {
{
auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); };
auto x = array({1, 2, 3, 4}, {2, 2});
auto out = vjp(fn, x, array({1, 1}, {1, 2})).second;
CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item<bool>());
}
{
auto fn = [](const std::vector<array>& inputs) {
const auto& x = inputs[0];
const auto& update = inputs[1];
return std::vector<array>{slice_update(x, update, array({0}), {0})};
};
auto x = zeros({2, 2});
auto update = array({3.f, 4.f}, {1, 2});
auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second;
CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item<bool>());
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
}
}

View File

@ -250,7 +250,7 @@ TEST_CASE("test QR factorization") {
// Unsupported types throw // Unsupported types throw
CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2}))); CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));
array A = array({{2., 3., 1., 2.}, {2, 2}}); array A = array({2., 3., 1., 2.}, {2, 2});
auto [Q, R] = linalg::qr(A, Device::cpu); auto [Q, R] = linalg::qr(A, Device::cpu);
auto out = matmul(Q, R); auto out = matmul(Q, R);
CHECK(allclose(out, A).item<bool>()); CHECK(allclose(out, A).item<bool>());

View File

@ -353,6 +353,50 @@ TEST_CASE("test slice update") {
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>()); CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
} }
TEST_CASE("test dynamic slice") {
auto src = reshape(arange(6), {2, 3});
CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));
CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1}));
CHECK_THROWS(slice(src, array({1}), {3}, {1, 1}));
CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1}));
CHECK_THROWS(slice(src, array({1}), {0}, {2, 4}));
CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1}));
auto out = slice(src, array({1}), {0}, {1, 2});
auto expected = array({3, 4}, {1, 2});
CHECK(array_equal(out, expected).item<bool>());
out = slice(src, array({1, 1}), {0, 1}, {1, 2});
expected = array({4, 5}, {1, 2});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test dynamic slice update") {
auto src = zeros({2, 3}, int32);
auto upd = ones({1, 2}, int32);
CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0}));
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0}));
CHECK_THROWS(slice_update(src, upd, array({1}), {3}));
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0}));
upd = ones({4}, int32);
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
upd = ones({1, 4}, int32);
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0}));
upd = ones({1, 2}, int32);
auto out = slice_update(src, upd, array({1}), {0});
auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});
CHECK(array_equal(out, expected).item<bool>());
upd = ones({1, 2}, int32);
out = slice_update(src, upd, array({1, 1}), {0, 1});
expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test split") { TEST_CASE("test split") {
array x = array(1); array x = array(1);
CHECK_THROWS(split(x, 0)); CHECK_THROWS(split(x, 0));
@ -720,7 +764,7 @@ TEST_CASE("test is inf") {
CHECK_FALSE(any(isinf(z)).item<bool>()); CHECK_FALSE(any(isinf(z)).item<bool>());
array w = array({1.0f, inf, 2.0f}); array w = array({1.0f, inf, 2.0f});
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>()); CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());
array a(1.0f, bfloat16); array a(1.0f, bfloat16);
CHECK_FALSE(isinf(a).item<bool>()); CHECK_FALSE(isinf(a).item<bool>());

View File

@ -686,7 +686,7 @@ TEST_CASE("test laplace") {
CHECK(std::abs(sample_variance - expected_variance) < 0.01); CHECK(std::abs(sample_variance - expected_variance) < 0.01);
// Expected kurtosis of Laplace distribution is 3. // Expected kurtosis of Laplace distribution is 3.
array fourth_pows = power(out - sample_mean, {4}); array fourth_pows = power(out - sample_mean, array(4));
float sample_kurtosis = float sample_kurtosis =
mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3; mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3;
float expected_kurtosis = 3.0; float expected_kurtosis = 3.0;

View File

@ -496,3 +496,33 @@ TEST_CASE("test vmap SVD") {
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)}); CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
} }
} }
TEST_CASE("test vmap dynamic slices") {
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{slice(inputs[0], array({1}), {0}, {2})};
};
auto x = reshape(arange(12), {3, 4});
auto out = vmap(fun)({x})[0];
CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item<bool>());
out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0];
CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4}))
.item<bool>());
}
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{
slice_update(inputs[0], inputs[1], array({1}), {0})};
};
auto x = zeros({2, 2});
auto upd = ones({2, 1});
auto out = vmap(fun)({x, upd})[0];
CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item<bool>());
out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0];
CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item<bool>());
}
}