mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives * add transforms * comment
This commit is contained in:
parent
c6fc07f1f4
commit
b7c9f1d38f
@ -16,11 +16,6 @@ inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||
return idx;
|
||||
@ -169,14 +164,11 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
||||
|
||||
if (inds.empty()) {
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
@ -201,12 +193,142 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case int64:
|
||||
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Gather::eval] Cannot gather with floating point indices.");
|
||||
"[Gather::eval_cpu] Cannot gather with indices type.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
template <typename T, typename IdxT>
|
||||
void gather_axis(
|
||||
const array& src,
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
strides = src.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
auto ind_ptr = ind.data<IdxT>();
|
||||
auto src_ptr = src.data<T>();
|
||||
auto dst_ptr = out.data<T>();
|
||||
auto ind_ax_stride = ind.strides(axis);
|
||||
auto src_ax_stride = src.strides(axis);
|
||||
auto dst_ax_stride = out.strides(axis);
|
||||
auto ind_ax_size = ind.shape(axis);
|
||||
auto src_ax_size = src.shape(axis);
|
||||
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
size_pre *= ind.shape(i);
|
||||
}
|
||||
for (int i = axis + 1; i < ind.ndim(); ++i) {
|
||||
size_post *= ind.shape(i);
|
||||
}
|
||||
size_t stride_pre = size_post * ind_ax_size;
|
||||
for (size_t i = 0; i < size_pre; i++) {
|
||||
for (size_t k = 0; k < size_post; k++) {
|
||||
for (int j = 0; j < ind_ax_size; ++j) {
|
||||
auto ind_val = offset_neg_idx(
|
||||
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
|
||||
dst_ptr[k + j * dst_ax_stride] =
|
||||
src_ptr[src_it.loc + ind_val * src_ax_stride];
|
||||
}
|
||||
ind_it.step();
|
||||
src_it.step();
|
||||
}
|
||||
dst_ptr += stride_pre;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IdxT>
|
||||
void dispatch_gather_axis(
|
||||
const array& src,
|
||||
const array& inds,
|
||||
array& out,
|
||||
const int axis) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather_axis<bool, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint8:
|
||||
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint16:
|
||||
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint32:
|
||||
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case uint64:
|
||||
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int8:
|
||||
gather_axis<int8_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int16:
|
||||
gather_axis<int16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int32:
|
||||
gather_axis<int32_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case int64:
|
||||
gather_axis<int64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float16:
|
||||
gather_axis<float16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float32:
|
||||
gather_axis<float, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case complex64:
|
||||
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& src = inputs[0];
|
||||
auto& inds = inputs[1];
|
||||
switch (inds.dtype()) {
|
||||
case uint8:
|
||||
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -296,14 +418,11 @@ void dispatch_scatter(
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
if (inds.empty()) {
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
@ -328,12 +447,9 @@ void dispatch_scatter(
|
||||
case int64:
|
||||
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
|
||||
"[Scatter::eval_cpu] Cannot scatter with indices type.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -345,7 +461,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& updates = inputs.back();
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
copy(src, out, CopyType::General);
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
@ -390,4 +508,167 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(
|
||||
array& out,
|
||||
const array idx,
|
||||
const array& upd,
|
||||
int axis,
|
||||
const OpT& op) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
auto idx_ptr = idx.data<IdxT>();
|
||||
auto upd_ptr = upd.data<T>();
|
||||
auto dst_ptr = out.data<T>();
|
||||
auto idx_ax_stride = idx.strides(axis);
|
||||
auto upd_ax_stride = upd.strides(axis);
|
||||
auto dst_ax_stride = out.strides(axis);
|
||||
auto idx_ax_size = idx.shape(axis);
|
||||
auto dst_ax_size = out.shape(axis);
|
||||
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
size_t stride_pre = size_post * dst_ax_size;
|
||||
for (size_t i = 0; i < size_pre; i++) {
|
||||
for (size_t k = 0; k < size_post; k++) {
|
||||
for (int j = 0; j < idx_ax_size; ++j) {
|
||||
auto ind_val = offset_neg_idx(
|
||||
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
|
||||
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
|
||||
dst_ptr + k + ind_val * dst_ax_stride);
|
||||
}
|
||||
idx_it.step();
|
||||
upd_it.step();
|
||||
}
|
||||
dst_ptr += stride_pre;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT>
|
||||
void dispatch_scatter_axis_op(
|
||||
array& out,
|
||||
const array& idx,
|
||||
const array& updates,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case ScatterAxis::None:
|
||||
scatter_axis<InT, IdxT>(
|
||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
scatter_axis<InT, IdxT>(
|
||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void dispatch_scatter_axis(
|
||||
array& out,
|
||||
const array& idx,
|
||||
const array& updates,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType rtype) {
|
||||
switch (idx.dtype()) {
|
||||
case uint8:
|
||||
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
auto& updates = inputs[2];
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter_axis<complex64_t>(
|
||||
out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter kernels/indexing.h)
|
||||
make_jit_source(gather kernels/indexing.h)
|
||||
make_jit_source(gather_axis)
|
||||
make_jit_source(scatter_axis)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if(MLX_METAL_JIT)
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/indexing.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@ -388,4 +389,217 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"gather_axis{0}{1}_{2}",
|
||||
type_to_name(out),
|
||||
type_to_name(idx),
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
kernel_name += src.flags().row_contiguous ? "c" : "nc";
|
||||
kernel_name += idx.flags().row_contiguous ? "c" : "nc";
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::gather_axis();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str = get_type_string(idx.dtype());
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
bool sc = i & 1;
|
||||
bool ic = i & 2;
|
||||
kernel_source += get_template_definition(
|
||||
lib_name + (sc ? "c" : "nc") + (ic ? "c" : "nc"),
|
||||
"gather_axis",
|
||||
out_type_str,
|
||||
idx_type_str,
|
||||
large ? "int64_t" : "int",
|
||||
sc ? "true" : "false",
|
||||
ic ? "true" : "false");
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
|
||||
int idx_ax_size = idx.shape(axis_);
|
||||
auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
|
||||
MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(src, 0);
|
||||
compute_encoder.set_input_array(idx, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = src.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(src.shape(axis_), 8);
|
||||
compute_encoder.set_bytes(src.strides(axis_), 9);
|
||||
compute_encoder.set_bytes(idx.strides(axis_), 10);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
auto& idx = inputs[1];
|
||||
auto& upd = inputs[2];
|
||||
|
||||
// Copy src into out
|
||||
CopyType copy_type;
|
||||
if (src.data_size() == 1) {
|
||||
copy_type = CopyType::Scalar;
|
||||
} else if (src.flags().row_contiguous) {
|
||||
copy_type = CopyType::Vector;
|
||||
} else {
|
||||
copy_type = CopyType::General;
|
||||
}
|
||||
copy_gpu(src, out, copy_type);
|
||||
|
||||
// Empty update
|
||||
if (upd.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
|
||||
|
||||
std::string op_name;
|
||||
switch (reduce_type_) {
|
||||
case ScatterAxis::None:
|
||||
op_name = "none";
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
op_name = "sum";
|
||||
break;
|
||||
}
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter_axis{0}{1}_{2}_{3}",
|
||||
type_to_name(out),
|
||||
type_to_name(idx),
|
||||
op_name,
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
kernel_name += upd.flags().row_contiguous ? "c" : "nc";
|
||||
kernel_name += idx.flags().row_contiguous ? "c" : "nc";
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::reduce_utils();
|
||||
kernel_source += metal::scatter_axis();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str = get_type_string(idx.dtype());
|
||||
std::string op_type;
|
||||
switch (reduce_type_) {
|
||||
case ScatterAxis::None:
|
||||
op_type = "None";
|
||||
break;
|
||||
case ScatterAxis::Sum:
|
||||
op_type = "Sum<" + out_type_str + ">";
|
||||
break;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
bool uc = i & 1;
|
||||
bool ic = i & 2;
|
||||
kernel_source += get_template_definition(
|
||||
lib_name + (uc ? "c" : "nc") + (ic ? "c" : "nc"),
|
||||
"scatter_axis",
|
||||
out_type_str,
|
||||
idx_type_str,
|
||||
large ? "int64_t" : "int",
|
||||
op_type,
|
||||
uc ? "true" : "false",
|
||||
ic ? "true" : "false");
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Grid [size post, index size, size pre]
|
||||
size_t size_pre = 1;
|
||||
size_t size_post = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
size_pre *= idx.shape(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
|
||||
size_post *= idx.shape(i);
|
||||
}
|
||||
|
||||
int idx_ax_size = idx.shape(axis_);
|
||||
auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
|
||||
MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
|
||||
|
||||
// Set all the buffers
|
||||
compute_encoder.set_input_array(upd, 0);
|
||||
compute_encoder.set_input_array(idx, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set source info
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(shape, 3);
|
||||
|
||||
auto strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 4);
|
||||
|
||||
strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis_);
|
||||
compute_encoder.set_vector_bytes(strides, 5);
|
||||
compute_encoder.set_bytes(ndim - 1, 6);
|
||||
compute_encoder.set_bytes(axis_, 7);
|
||||
compute_encoder.set_bytes(out.shape(axis_), 8);
|
||||
compute_encoder.set_bytes(upd.strides(axis_), 9);
|
||||
compute_encoder.set_bytes(idx.strides(axis_), 10);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -18,10 +18,12 @@ const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* gather_axis();
|
||||
const char* hadamard();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* scatter_axis();
|
||||
const char* softmax();
|
||||
const char* sort();
|
||||
const char* reduce();
|
||||
|
44
mlx/backend/metal/kernels/gather_axis.h
Normal file
44
mlx/backend/metal/kernels/gather_axis.h
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
|
||||
[[kernel]] void gather_axis(
|
||||
const device T* src [[buffer(0)]],
|
||||
const device IdxT* indices [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant int* shape [[buffer(3)]],
|
||||
const constant int64_t* src_strides [[buffer(4)]],
|
||||
const constant int64_t* idx_strides [[buffer(5)]],
|
||||
const constant size_t& ndim [[buffer(6)]],
|
||||
const constant int& axis [[buffer(7)]],
|
||||
const constant int& axis_size [[buffer(8)]],
|
||||
const constant size_t& src_ax_stride [[buffer(9)]],
|
||||
const constant size_t& idx_ax_stride [[buffer(10)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
|
||||
LocT out_idx = elem_idx * grid_dim.y + index.x;
|
||||
|
||||
LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
|
||||
if (IdxC) {
|
||||
idx_loc += out_idx;
|
||||
} else {
|
||||
idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
|
||||
}
|
||||
|
||||
auto idx_val = indices[idx_loc];
|
||||
if (is_signed_v<IdxT>) {
|
||||
idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
|
||||
}
|
||||
|
||||
LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
|
||||
if (SrcC) {
|
||||
src_idx += elem_idx * axis_size + index.x;
|
||||
} else {
|
||||
src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
|
||||
}
|
||||
|
||||
out_idx += index.y * static_cast<LocT>(grid_dim.x);
|
||||
out[out_idx] = src[src_idx];
|
||||
}
|
52
mlx/backend/metal/kernels/scatter_axis.h
Normal file
52
mlx/backend/metal/kernels/scatter_axis.h
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename LocT,
|
||||
typename Op,
|
||||
bool UpdC,
|
||||
bool IdxC>
|
||||
[[kernel]] void scatter_axis(
|
||||
const device T* upd [[buffer(0)]],
|
||||
const device IdxT* indices [[buffer(1)]],
|
||||
device mlx_atomic<T>* out [[buffer(2)]],
|
||||
const constant int* shape [[buffer(3)]],
|
||||
const constant int64_t* upd_strides [[buffer(4)]],
|
||||
const constant int64_t* idx_strides [[buffer(5)]],
|
||||
const constant size_t& ndim [[buffer(6)]],
|
||||
const constant int& axis [[buffer(7)]],
|
||||
const constant int& out_axis_size [[buffer(8)]],
|
||||
const constant size_t& upd_ax_stride [[buffer(9)]],
|
||||
const constant size_t& idx_ax_stride [[buffer(10)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
Op op;
|
||||
|
||||
LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
|
||||
|
||||
LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
|
||||
if (IdxC) {
|
||||
idx_loc += elem_idx * grid_dim.y + index.x;
|
||||
} else {
|
||||
idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
|
||||
}
|
||||
|
||||
auto idx_val = indices[idx_loc];
|
||||
if (is_signed_v<IdxT>) {
|
||||
idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
|
||||
}
|
||||
|
||||
LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
|
||||
if (UpdC) {
|
||||
upd_idx += elem_idx * grid_dim.y + index.x;
|
||||
} else {
|
||||
upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
|
||||
}
|
||||
|
||||
LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
|
||||
idx_val * grid_dim.x + index.x;
|
||||
op.atomic_update(out, upd[upd_idx], out_idx);
|
||||
}
|
@ -65,6 +65,7 @@ NO_CPU(Flatten)
|
||||
NO_CPU(Floor)
|
||||
NO_CPU(Full)
|
||||
NO_CPU(Gather)
|
||||
NO_CPU(GatherAxis)
|
||||
NO_CPU(GatherMM)
|
||||
NO_CPU(GatherQMM)
|
||||
NO_CPU(Greater)
|
||||
@ -98,6 +99,7 @@ NO_CPU(Reshape)
|
||||
NO_CPU(Round)
|
||||
NO_CPU(Scan)
|
||||
NO_CPU(Scatter)
|
||||
NO_CPU(ScatterAxis)
|
||||
NO_CPU(Select)
|
||||
NO_CPU(Sigmoid)
|
||||
NO_CPU(Sign)
|
||||
|
@ -65,6 +65,7 @@ NO_GPU(Flatten)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Full)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(GatherAxis)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
NO_GPU(Greater)
|
||||
@ -98,6 +99,7 @@ NO_GPU(Reshape)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(ScatterAxis)
|
||||
NO_GPU(Select)
|
||||
NO_GPU(Sigmoid)
|
||||
NO_GPU(Sign)
|
||||
|
142
mlx/ops.cpp
142
mlx/ops.cpp
@ -68,7 +68,7 @@ array indices_or_default(
|
||||
Shape shape(x.shape().begin(), x.shape().end() - 2);
|
||||
int total =
|
||||
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
||||
return reshape(arange(total, uint32, s), shape, s);
|
||||
return reshape(arange(total, uint32, s), std::move(shape), s);
|
||||
}
|
||||
|
||||
std::pair<int, int> extract_quantized_matmul_dims(
|
||||
@ -3080,28 +3080,64 @@ array take_along_axis(
|
||||
// Allow negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<array> nd_indices;
|
||||
Shape index_shape(a.ndim(), 1);
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (i == axis) {
|
||||
nd_indices.push_back(indices);
|
||||
} else {
|
||||
// Reshape so they can be broadcast
|
||||
index_shape[i] = a.shape(i);
|
||||
nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));
|
||||
index_shape[i] = 1;
|
||||
}
|
||||
}
|
||||
std::vector<int> dims(a.ndim());
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
Shape slice_sizes(a.ndim(), 1);
|
||||
auto out = gather(a, nd_indices, dims, slice_sizes, s);
|
||||
// Broadcast indices and input ignoring the take axis
|
||||
auto inputs = broadcast_arrays({a, indices}, {axis - int(a.ndim())}, s);
|
||||
|
||||
// Squeeze out the slice shape
|
||||
for (auto& d : dims) {
|
||||
d += a.ndim();
|
||||
auto out_shape = inputs[1].shape();
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<GatherAxis>(to_stream(s), axis),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
array scatter_axis(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
const array& values,
|
||||
int axis,
|
||||
ScatterAxis::ReduceType mode,
|
||||
StreamOrDevice s) {
|
||||
std::string prefix =
|
||||
(mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]";
|
||||
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Received invalid axis " << " for array with " << a.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return squeeze(out, dims, s);
|
||||
|
||||
if (indices.ndim() != a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Indices of dimension " << indices.ndim()
|
||||
<< " does not match array of dimension " << a.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto upd = astype(values, a.dtype(), s);
|
||||
|
||||
// Squeeze leading singletons out of update
|
||||
if (upd.ndim() > indices.ndim()) {
|
||||
std::vector<int> sq_ax(upd.ndim() - indices.ndim());
|
||||
std::iota(sq_ax.begin(), sq_ax.end(), 0);
|
||||
upd = squeeze(upd, sq_ax, s);
|
||||
}
|
||||
|
||||
auto inputs = broadcast_arrays({indices, upd}, s);
|
||||
inputs.insert(inputs.begin(), a);
|
||||
|
||||
// Allow negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
// Broadcast src, indices, values while ignoring the take axis
|
||||
inputs = broadcast_arrays(inputs, {axis - int(a.ndim())}, s);
|
||||
|
||||
auto out_shape = inputs[0].shape();
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<ScatterAxis>(to_stream(s), mode, axis),
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
array put_along_axis(
|
||||
@ -3110,45 +3146,16 @@ array put_along_axis(
|
||||
const array& values,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[put_along_axis] Received invalid axis " << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return scatter_axis(a, indices, values, axis, ScatterAxis::None, s);
|
||||
}
|
||||
|
||||
if (indices.ndim() != a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[put_along_axis] Indices of dimension " << indices.ndim()
|
||||
<< " does not match array of dimension " << a.ndim() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Allow negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<array> nd_indices;
|
||||
Shape index_shape(a.ndim(), 1);
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (i == axis) {
|
||||
nd_indices.push_back(indices);
|
||||
} else {
|
||||
// Reshape so they can be broadcast
|
||||
index_shape[i] = a.shape(i);
|
||||
nd_indices.push_back(reshape(arange(a.shape(i), s), index_shape, s));
|
||||
index_shape[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto update = astype(broadcast_to(values, indices.shape(), s), a.dtype(), s);
|
||||
{
|
||||
auto update_shape = update.shape();
|
||||
update_shape.resize(update_shape.size() + a.ndim(), 1);
|
||||
update = reshape(update, std::move(update_shape), s);
|
||||
}
|
||||
std::vector<int> dims(a.ndim());
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
return scatter(a, nd_indices, update, dims, s);
|
||||
array scatter_add_axis(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
const array& values,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return scatter_axis(a, indices, values, axis, ScatterAxis::Sum, s);
|
||||
}
|
||||
|
||||
/** Scatter updates to given indices */
|
||||
@ -3157,8 +3164,8 @@ array scatter(
|
||||
const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType mode /*= Scatter::ReduceType::None*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
Scatter::ReduceType mode,
|
||||
StreamOrDevice s) {
|
||||
// Checks that indices, dimensions, and slice_sizes are all valid
|
||||
if (indices.size() > a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
@ -3962,6 +3969,19 @@ array gather_qmm(
|
||||
std::tie(lhs_indices, rhs_indices) =
|
||||
broadcast_arrays(lhs_indices, rhs_indices, s);
|
||||
|
||||
if (!issubdtype(lhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather_qmm] Got lhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
if (!issubdtype(rhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[gather_qmm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
lhs_indices = astype(lhs_indices, uint32, s);
|
||||
rhs_indices = astype(rhs_indices, uint32, s);
|
||||
|
||||
// Compute the full output shape
|
||||
auto out_shape = lhs_indices.shape();
|
||||
out_shape.push_back(x.shape(-2));
|
||||
|
@ -968,6 +968,14 @@ array put_along_axis(
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Add the values into the array at the given indices along the axis */
|
||||
array scatter_add_axis(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
const array& values,
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Scatter updates to the given indices.
|
||||
*
|
||||
* The parameters ``indices`` and ``axes`` determine the locations of ``a``
|
||||
|
@ -2098,6 +2098,77 @@ bool Gather::is_equivalent(const Primitive& other) const {
|
||||
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> GatherAxis::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
bool vmap_in = axes[0] >= 0;
|
||||
bool vmap_idx = axes[1] >= 0;
|
||||
|
||||
auto in = inputs[0];
|
||||
auto idx = inputs[1];
|
||||
int out_ax;
|
||||
if (vmap_in && vmap_idx) {
|
||||
// reorder the vmap axes to the same location
|
||||
idx = moveaxis(idx, axes[1], axes[0], stream());
|
||||
out_ax = axes[0];
|
||||
} else if (vmap_in) {
|
||||
// expand just the indices dimension
|
||||
idx = expand_dims(idx, axes[0], stream());
|
||||
out_ax = axes[0];
|
||||
} else if (vmap_idx) {
|
||||
// expand just the input dimension
|
||||
in = expand_dims(in, axes[1], stream());
|
||||
out_ax = axes[1];
|
||||
} else {
|
||||
out_ax = -1;
|
||||
}
|
||||
int axis = (out_ax >= 0 && axis_ >= out_ax) ? axis_ + 1 : axis_;
|
||||
return {{take_along_axis(in, idx, axis, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> GatherAxis::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (int argnum : argnums) {
|
||||
if (argnum > 0) {
|
||||
// Grads w.r.t. indices are zero
|
||||
vjps.push_back(
|
||||
zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
|
||||
} else {
|
||||
auto src = zeros_like(primals[0], stream());
|
||||
vjps.push_back(array(
|
||||
src.shape(),
|
||||
src.dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), ScatterAxis::Sum, axis_),
|
||||
{src, primals[1], cotangents[0]}));
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> GatherAxis::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||
throw std::invalid_argument(
|
||||
"[gather_axis] Cannot calculate JVP with respect to indices.");
|
||||
}
|
||||
return {take_along_axis(tangents[0], primals[1], axis_, stream())};
|
||||
}
|
||||
|
||||
std::vector<Shape> GatherAxis::output_shapes(const std::vector<array>& inputs) {
|
||||
return {inputs[1].shape()};
|
||||
}
|
||||
|
||||
bool GatherAxis::is_equivalent(const Primitive& other) const {
|
||||
auto& g_other = static_cast<const GatherAxis&>(other);
|
||||
return axis_ == g_other.axis_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
||||
Shape out_shape;
|
||||
if (inputs.size() > 1) {
|
||||
@ -3621,6 +3692,117 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
|
||||
return {{out}, {src_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> ScatterAxis::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
const auto& indices = primals[1];
|
||||
const auto& updates = primals[2];
|
||||
|
||||
std::vector<array> vjps;
|
||||
for (auto num : argnums) {
|
||||
// Gradient wrt to the input array
|
||||
if (num == 0) {
|
||||
if (reduce_type_ == ScatterAxis::None) {
|
||||
// Scatter 0s to the locations that were updated with the updates
|
||||
vjps.push_back(put_along_axis(
|
||||
cotangents[0],
|
||||
indices,
|
||||
zeros_like(updates, stream()),
|
||||
axis_,
|
||||
stream()));
|
||||
} else {
|
||||
// The input array values are kept so they all get gradients
|
||||
vjps.push_back(cotangents[0]);
|
||||
}
|
||||
} else if (num == 2) {
|
||||
vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream()));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[scatter_axis] Cannot calculate VJP with respect to indices.");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> ScatterAxis::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 1) {
|
||||
throw std::invalid_argument(
|
||||
"[scatter_axis] Cannot calculate JVP with respect to indices.");
|
||||
}
|
||||
}
|
||||
if (argnums.size() == 2) {
|
||||
return {array(
|
||||
primals[0].shape(),
|
||||
primals[0].dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
|
||||
{tangents[0], primals[1], tangents[1]})};
|
||||
} else {
|
||||
auto tan_a =
|
||||
argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream());
|
||||
auto tan_b =
|
||||
argnums[0] == 2 ? tangents[0] : zeros_like(primals[2], stream());
|
||||
return {array(
|
||||
primals[0].shape(),
|
||||
primals[0].dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
|
||||
{tan_a, primals[1], tan_b})};
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ScatterAxis::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
// Find the first vmap axis
|
||||
int out_ax = -1;
|
||||
for (auto ax : axes) {
|
||||
if (ax >= 0) {
|
||||
out_ax = ax;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (out_ax < 0) {
|
||||
return {
|
||||
{array(
|
||||
inputs[0].shape(),
|
||||
inputs[0].dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
|
||||
inputs)},
|
||||
{-1}};
|
||||
}
|
||||
|
||||
auto v_in = inputs;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
if (axes[i] >= 0) {
|
||||
// if out_ax >= 0 move axis o/w set out_ax
|
||||
if (out_ax != axes[i]) {
|
||||
v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream());
|
||||
}
|
||||
} else {
|
||||
v_in[i] = expand_dims(v_in[i], out_ax, stream());
|
||||
}
|
||||
}
|
||||
int axis = axis_ >= out_ax ? axis_ + 1 : axis_;
|
||||
auto fn = reduce_type_ == Sum ? scatter_add_axis : put_along_axis;
|
||||
return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
std::vector<Shape> ScatterAxis::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
return {inputs[0].shape()};
|
||||
}
|
||||
|
||||
bool ScatterAxis::is_equivalent(const Primitive& other) const {
|
||||
auto& s_other = static_cast<const ScatterAxis&>(other);
|
||||
return reduce_type_ == s_other.reduce_type_ && axis_ == s_other.axis_;
|
||||
}
|
||||
|
||||
std::vector<array> Sigmoid::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -1095,6 +1095,27 @@ class Gather : public UnaryPrimitive {
|
||||
Shape slice_sizes_;
|
||||
};
|
||||
|
||||
class GatherAxis : public UnaryPrimitive {
|
||||
public:
|
||||
explicit GatherAxis(Stream stream, int axis)
|
||||
: UnaryPrimitive(stream), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(GatherAxis)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
auto state() const {
|
||||
return axis_;
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
class Greater : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Greater(Stream stream) : UnaryPrimitive(stream) {}
|
||||
@ -1786,6 +1807,41 @@ class Scatter : public UnaryPrimitive {
|
||||
std::vector<int> axes_;
|
||||
};
|
||||
|
||||
class ScatterAxis : public UnaryPrimitive {
|
||||
public:
|
||||
enum ReduceType { Sum, None };
|
||||
|
||||
explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
|
||||
: UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
os << "ScatterAxis";
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
os << " Sum";
|
||||
break;
|
||||
case None:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::pair<ReduceType, int> state() const {
|
||||
return {reduce_type_, axis_};
|
||||
}
|
||||
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
int axis_;
|
||||
};
|
||||
|
||||
class Sigmoid : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
|
||||
|
@ -669,6 +669,37 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
def test_put_along_axis_grads(self):
|
||||
a = mx.zeros((5, 1))
|
||||
b = mx.ones((2, 1))
|
||||
|
||||
def fun(a, b):
|
||||
idx = mx.array([[0], [3]])
|
||||
return mx.put_along_axis(a, idx, b, axis=0)
|
||||
|
||||
# Test VJP
|
||||
cotan = mx.full((5, 1), 2.0)
|
||||
_, (da, db) = mx.vjp(fun, (a, b), (cotan,))
|
||||
expected_da = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
expected_db = mx.array([2.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected_da, da))
|
||||
self.assertTrue(mx.allclose(expected_db, db))
|
||||
|
||||
# Test JVP
|
||||
tan_a = mx.full((5, 1), 2.0)
|
||||
tan_b = mx.full((2, 1), 3.0)
|
||||
_, (jout,) = mx.jvp(fun, (a, b), (tan_a, tan_b))
|
||||
expected = mx.array([3.0, 2.0, 2.0, 3.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
def fun(a):
|
||||
idx = mx.array([[0], [3]])
|
||||
return mx.put_along_axis(a, idx, b, axis=0)
|
||||
|
||||
_, (jout,) = mx.jvp(fun, (a,), (tan_a,))
|
||||
expected = mx.array([0.0, 2.0, 2.0, 0.0, 2.0])[:, None]
|
||||
self.assertTrue(mx.allclose(expected, jout))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1150,6 +1150,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_mlx = mx.put_along_axis(a_mlx, idx_mlx, values_mlx, axis=ax)
|
||||
self.assertTrue(np.array_equal(a_np, out_mlx))
|
||||
|
||||
source = mx.zeros((1, 1, 8, 32))
|
||||
indices = mx.array([0, 2, 4, 5]).reshape((1, 1, 4, 1))
|
||||
update = mx.array(1.0)
|
||||
|
||||
out_mlx = mx.put_along_axis(source, indices, update, axis=-2)
|
||||
out_np = np.array(source)
|
||||
np.put_along_axis(out_np, np.array(indices), np.array(update), axis=-2)
|
||||
self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
|
||||
|
||||
def test_split(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
splits = mx.split(a, 3)
|
||||
|
@ -549,6 +549,53 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2)
|
||||
self.assertTrue(mx.array_equal(out, target))
|
||||
|
||||
def test_vmap_take_along_axis(self):
|
||||
a = mx.zeros((4, 5, 1))
|
||||
idx = mx.zeros((2, 4, 1), mx.int32)
|
||||
|
||||
def fun(a, idx):
|
||||
return mx.take_along_axis(a, idx, axis=0)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, 1))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
idx = mx.zeros((2, 1), mx.int32)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, None))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
a = mx.zeros((5, 1))
|
||||
idx = mx.zeros((4, 2, 1), mx.int32)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(None, 0))(a, idx)
|
||||
self.assertEqual(out.shape, (4, 2, 1))
|
||||
|
||||
def test_vmap_put_along_axis(self):
|
||||
a = mx.zeros((4, 5, 1))
|
||||
idx = mx.ones((2, 4, 1), mx.int32)
|
||||
upd = mx.ones((2, 4, 1))
|
||||
|
||||
def fun(a, idx, upd):
|
||||
return mx.put_along_axis(a, idx, upd, axis=0)
|
||||
|
||||
out = mx.vmap(fun, in_axes=(0, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
upd = mx.ones((2, 1))
|
||||
out = mx.vmap(fun, in_axes=(0, 1, None))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
idx = mx.ones((2, 1), mx.int32)
|
||||
upd = mx.ones((2, 1))
|
||||
out = mx.vmap(fun, in_axes=(0, None, None))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
a = mx.zeros((5, 1))
|
||||
idx = mx.ones((2, 4, 1), mx.int32)
|
||||
upd = mx.ones((2, 4, 1))
|
||||
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
|
||||
self.assertEqual(out.shape, (4, 5, 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user