scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
Author: Awni Hannun
Date: 2025-01-31 20:48:08 -08:00
Committed by: GitHub
Parent: c6fc07f1f4
Commit: b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -16,11 +16,6 @@ inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
- template <>
- inline size_t offset_neg_idx(bool idx, size_t) {
- return idx;
- }
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
@@ -169,14 +164,11 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
- dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
+ dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
- case bool_:
- dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
- break;
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
@@ -201,12 +193,142 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
- case float16:
- case float32:
- case bfloat16:
- case complex64:
+ default:
throw std::runtime_error(
- "[Gather::eval] Cannot gather with floating point indices.");
+ "[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
}
template <typename T, typename IdxT>
void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
auto dst_ptr = out.data<T>();
auto ind_ax_stride = ind.strides(axis);
auto src_ax_stride = src.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto ind_ax_size = ind.shape(axis);
auto src_ax_size = src.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= ind.shape(i);
}
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
}
ind_it.step();
src_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename IdxT>
void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break;
}
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
}
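
Note (editorial, not part of the commit): GatherAxis implements the usual take-along-axis semantics, i.e. out[..., j, ...] = src[..., inds[..., j, ...], ...] along the chosen axis, with negative indices wrapping as in offset_neg_idx above. The following is a minimal standalone C++ reference model of that behavior for a row-major 2D case with axis = 1; the function name and shapes are illustrative assumptions, not MLX API.

// Reference model only (not from this commit): take-along-axis on a
// row-major 2D array with axis = 1.
#include <cstdint>
#include <vector>

std::vector<float> take_along_axis_2d(
    const std::vector<float>& src, // shape [rows, src_cols]
    const std::vector<int32_t>& inds, // shape [rows, idx_cols]
    int rows,
    int src_cols,
    int idx_cols) {
  std::vector<float> out(static_cast<size_t>(rows) * idx_cols);
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < idx_cols; ++j) {
      int64_t idx = inds[i * idx_cols + j];
      if (idx < 0) {
        idx += src_cols; // wrap negative indices
      }
      out[i * idx_cols + j] = src[i * src_cols + idx];
    }
  }
  return out;
}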
@@ -296,14 +418,11 @@ void dispatch_scatter(
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
- dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
+ dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
- case bool_:
- dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
- break;
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
@@ -328,12 +447,9 @@ void dispatch_scatter(
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
- case float16:
- case float32:
- case bfloat16:
- case complex64:
+ default:
throw std::runtime_error(
- "[Scatter::eval_cpu] Cannot scatter with floating point indices.");
+ "[Scatter::eval_cpu] Cannot scatter with indices type.");
}
}
@@ -345,7 +461,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
- copy(src, out, CopyType::General);
+ auto ctype =
+ src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
+ copy(src, out, ctype);
switch (src.dtype()) {
case bool_:
@@ -390,4 +508,167 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
auto dst_ptr = out.data<T>();
auto idx_ax_stride = idx.strides(axis);
auto upd_ax_stride = upd.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_axis_op(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
break;
}
}
template <typename InT>
void dispatch_scatter_axis(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
break;
default:
throw std::runtime_error(
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
}
}
void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
auto& idx = inputs[1];
auto& updates = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
}
} // namespace mlx::core
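
Note (editorial, not part of the commit): ScatterAxis is the inverse operation. The output starts as a copy of src, and each update element is either written (None) or accumulated (Sum) at the index given along the axis: out[..., inds[..., j, ...], ...] op= upd[..., j, ...]. Below is a minimal standalone 2D sketch with axis = 1; names, shapes, and the Reduce enum are illustrative assumptions.

// Reference model only (not from this commit): put-along-axis on a
// row-major 2D array with axis = 1, with assign and sum reductions.
#include <cstdint>
#include <vector>

enum class Reduce { None, Sum };

void put_along_axis_2d(
    std::vector<float>& out, // shape [rows, out_cols], pre-filled with src
    const std::vector<int32_t>& inds, // shape [rows, idx_cols]
    const std::vector<float>& upd, // shape [rows, idx_cols]
    int rows,
    int out_cols,
    int idx_cols,
    Reduce reduce) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < idx_cols; ++j) {
      int64_t idx = inds[i * idx_cols + j];
      if (idx < 0) {
        idx += out_cols; // wrap negative indices
      }
      float u = upd[i * idx_cols + j];
      if (reduce == Reduce::Sum) {
        out[i * out_cols + idx] += u; // accumulate
      } else {
        out[i * out_cols + idx] = u; // overwrite
      }
    }
  }
}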

View File

@@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
make_jit_source(hadamard)
if(MLX_METAL_JIT)

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -388,4 +389,217 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
auto& idx = inputs[1];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& s = stream();
auto& d = metal::device(s.device);
size_t ndim = src.ndim();
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
std::string kernel_name = fmt::format(
"gather_axis{0}{1}_{2}",
type_to_name(out),
type_to_name(idx),
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
kernel_name += src.flags().row_contiguous ? "c" : "nc";
kernel_name += idx.flags().row_contiguous ? "c" : "nc";
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::gather_axis();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = get_type_string(idx.dtype());
for (int i = 0; i < 4; ++i) {
bool sc = i & 1;
bool ic = i & 2;
kernel_source += get_template_definition(
lib_name + (sc ? "c" : "nc") + (ic ? "c" : "nc"),
"gather_axis",
out_type_str,
idx_type_str,
large ? "int64_t" : "int",
sc ? "true" : "false",
ic ? "true" : "false");
}
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
// Grid [size post, index size, size pre]
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis_; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
int idx_ax_size = idx.shape(axis_);
auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
// Set all the buffers
compute_encoder.set_input_array(src, 0);
compute_encoder.set_input_array(idx, 1);
compute_encoder.set_output_array(out, 2);
// Set source info
auto shape = idx.shape();
shape.erase(shape.begin() + axis_);
compute_encoder.set_vector_bytes(shape, 3);
auto strides = src.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(src.shape(axis_), 8);
compute_encoder.set_bytes(src.strides(axis_), 9);
compute_encoder.set_bytes(idx.strides(axis_), 10);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
auto& idx = inputs[1];
auto& upd = inputs[2];
// Copy src into out
CopyType copy_type;
if (src.data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (src.flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(src, out, copy_type);
// Empty update
if (upd.size() == 0) {
return;
}
auto& s = stream();
auto& d = metal::device(s.device);
size_t ndim = src.ndim();
bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
std::string op_name;
switch (reduce_type_) {
case ScatterAxis::None:
op_name = "none";
break;
case ScatterAxis::Sum:
op_name = "sum";
break;
}
std::string kernel_name = fmt::format(
"scatter_axis{0}{1}_{2}_{3}",
type_to_name(out),
type_to_name(idx),
op_name,
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
kernel_name += upd.flags().row_contiguous ? "c" : "nc";
kernel_name += idx.flags().row_contiguous ? "c" : "nc";
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::reduce_utils();
kernel_source += metal::scatter_axis();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = get_type_string(idx.dtype());
std::string op_type;
switch (reduce_type_) {
case ScatterAxis::None:
op_type = "None";
break;
case ScatterAxis::Sum:
op_type = "Sum<" + out_type_str + ">";
break;
}
for (int i = 0; i < 4; ++i) {
bool uc = i & 1;
bool ic = i & 2;
kernel_source += get_template_definition(
lib_name + (uc ? "c" : "nc") + (ic ? "c" : "nc"),
"scatter_axis",
out_type_str,
idx_type_str,
large ? "int64_t" : "int",
op_type,
uc ? "true" : "false",
ic ? "true" : "false");
}
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
// Grid [size post, index size, size pre]
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis_; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis_ + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
int idx_ax_size = idx.shape(axis_);
auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
// Set all the buffers
compute_encoder.set_input_array(upd, 0);
compute_encoder.set_input_array(idx, 1);
compute_encoder.set_output_array(out, 2);
// Set source info
auto shape = idx.shape();
shape.erase(shape.begin() + axis_);
compute_encoder.set_vector_bytes(shape, 3);
auto strides = upd.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(out.shape(axis_), 8);
compute_encoder.set_bytes(upd.strides(axis_), 9);
compute_encoder.set_bytes(idx.strides(axis_), 10);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -18,10 +18,12 @@ const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* gather_axis();
const char* hadamard();
const char* quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();
const char* softmax();
const char* sort();
const char* reduce();

View File

@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#pragma once
template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
[[kernel]] void gather_axis(
const device T* src [[buffer(0)]],
const device IdxT* indices [[buffer(1)]],
device T* out [[buffer(2)]],
const constant int* shape [[buffer(3)]],
const constant int64_t* src_strides [[buffer(4)]],
const constant int64_t* idx_strides [[buffer(5)]],
const constant size_t& ndim [[buffer(6)]],
const constant int& axis [[buffer(7)]],
const constant int& axis_size [[buffer(8)]],
const constant size_t& src_ax_stride [[buffer(9)]],
const constant size_t& idx_ax_stride [[buffer(10)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
LocT out_idx = elem_idx * grid_dim.y + index.x;
LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
if (IdxC) {
idx_loc += out_idx;
} else {
idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
}
auto idx_val = indices[idx_loc];
if (is_signed_v<IdxT>) {
idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
}
LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
if (SrcC) {
src_idx += elem_idx * axis_size + index.x;
} else {
src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
}
out_idx += index.y * static_cast<LocT>(grid_dim.x);
out[out_idx] = src[src_idx];
}
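
Note (editorial, not part of the commit): the kernel launches one thread per output element over a grid of [size post, index size, size pre]. For the fully contiguous case (SrcC and IdxC both true) the index arithmetic reduces to plain row-major offsets; the host-side C++ loop below mirrors that arithmetic as a sanity model. Layouts and names are assumptions for illustration.

// Sequential model of the fully contiguous case (SrcC = IdxC = true), assuming
// row-major [size_pre, axis, size_post] layouts for src, idx, and out.
#include <cstdint>
#include <vector>

void gather_axis_contiguous_model(
    const std::vector<float>& src, // [size_pre, src_ax_size, size_post]
    const std::vector<int32_t>& idx, // [size_pre, idx_ax_size, size_post]
    std::vector<float>& out, // [size_pre, idx_ax_size, size_post]
    int64_t size_pre,
    int64_t size_post,
    int64_t src_ax_size,
    int64_t idx_ax_size) {
  for (int64_t z = 0; z < size_pre; ++z) { // index.z
    for (int64_t y = 0; y < idx_ax_size; ++y) { // index.y
      for (int64_t x = 0; x < size_post; ++x) { // index.x
        int64_t elem_idx = z * size_post; // as in the kernel
        // Contiguous offset of element (z, y, x); shared by idx and out.
        int64_t idx_loc = elem_idx * idx_ax_size + y * size_post + x;
        int64_t idx_val = idx[idx_loc];
        if (idx_val < 0) {
          idx_val += src_ax_size; // wrap negative indices
        }
        // Replace the axis coordinate with idx_val in the source offset.
        int64_t src_loc = elem_idx * src_ax_size + idx_val * size_post + x;
        out[idx_loc] = src[src_loc];
      }
    }
  }
}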

View File

@@ -0,0 +1,52 @@
// Copyright © 2025 Apple Inc.
#pragma once
template <
typename T,
typename IdxT,
typename LocT,
typename Op,
bool UpdC,
bool IdxC>
[[kernel]] void scatter_axis(
const device T* upd [[buffer(0)]],
const device IdxT* indices [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* shape [[buffer(3)]],
const constant int64_t* upd_strides [[buffer(4)]],
const constant int64_t* idx_strides [[buffer(5)]],
const constant size_t& ndim [[buffer(6)]],
const constant int& axis [[buffer(7)]],
const constant int& out_axis_size [[buffer(8)]],
const constant size_t& upd_ax_stride [[buffer(9)]],
const constant size_t& idx_ax_stride [[buffer(10)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
Op op;
LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
if (IdxC) {
idx_loc += elem_idx * grid_dim.y + index.x;
} else {
idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
}
auto idx_val = indices[idx_loc];
if (is_signed_v<IdxT>) {
idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
}
LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
if (UpdC) {
upd_idx += elem_idx * grid_dim.y + index.x;
} else {
upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
}
LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
idx_val * grid_dim.x + index.x;
op.atomic_update(out, upd[upd_idx], out_idx);
}
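
Note (editorial, not part of the commit): scatter_axis uses the same grid layout, but writes go through mlx_atomic updates since several indices may target the same output location (which matters for the Sum reduction). In the fully contiguous case the destination offset is the row-major offset of (z, idx_val, x) in out. A sequential host-side sketch of the Sum path, under the same layout assumptions as above:

// Sequential model of the fully contiguous case (UpdC = IdxC = true). The GPU
// kernel performs the update atomically; this model ignores that concern.
#include <cstdint>
#include <vector>

void scatter_axis_sum_contiguous_model(
    std::vector<float>& out, // [size_pre, out_ax_size, size_post], holds a copy of src
    const std::vector<int32_t>& idx, // [size_pre, idx_ax_size, size_post]
    const std::vector<float>& upd, // [size_pre, idx_ax_size, size_post]
    int64_t size_pre,
    int64_t size_post,
    int64_t out_ax_size,
    int64_t idx_ax_size) {
  for (int64_t z = 0; z < size_pre; ++z) { // index.z
    for (int64_t y = 0; y < idx_ax_size; ++y) { // index.y
      for (int64_t x = 0; x < size_post; ++x) { // index.x
        int64_t elem_idx = z * size_post;
        // Contiguous offset of element (z, y, x); shared by idx and upd.
        int64_t loc = elem_idx * idx_ax_size + y * size_post + x;
        int64_t idx_val = idx[loc];
        if (idx_val < 0) {
          idx_val += out_ax_size; // wrap negative indices
        }
        int64_t out_loc = elem_idx * out_ax_size + idx_val * size_post + x;
        out[out_loc] += upd[loc]; // Sum reduction; None would assign instead
      }
    }
  }
}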

View File

@@ -65,6 +65,7 @@ NO_CPU(Flatten)
NO_CPU(Floor)
NO_CPU(Full)
NO_CPU(Gather)
NO_CPU(GatherAxis)
NO_CPU(GatherMM)
NO_CPU(GatherQMM)
NO_CPU(Greater)
@@ -98,6 +99,7 @@ NO_CPU(Reshape)
NO_CPU(Round)
NO_CPU(Scan)
NO_CPU(Scatter)
NO_CPU(ScatterAxis)
NO_CPU(Select)
NO_CPU(Sigmoid)
NO_CPU(Sign)

View File

@@ -65,6 +65,7 @@ NO_GPU(Flatten)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
@@ -98,6 +99,7 @@ NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)