Go to the source code of this file.
|
template<typename T, typename IdxT, typename LocT, typename Op, bool UpdC, bool IdxC>
void scatter_axis(const device T *upd, const device IdxT *indices, device mlx_atomic<T> *out, const constant int *shape, const constant int64_t *upd_strides, const constant int64_t *idx_strides, const constant size_t &ndim, const constant int &axis, const constant int &out_axis_size, const constant size_t &upd_ax_stride, const constant size_t &idx_ax_stride, uint3 index, uint3 grid_dim)
|
◆ scatter_axis()
template<typename T, typename IdxT, typename LocT, typename Op, bool UpdC, bool IdxC>
void scatter_axis(
    const device T *upd,
    const device IdxT *indices,
    device mlx_atomic<T> *out,
    const constant int *shape,
    const constant int64_t *upd_strides,
    const constant int64_t *idx_strides,
    const constant size_t &ndim,
    const constant int &axis,
    const constant int &out_axis_size,
    const constant size_t &upd_ax_stride,
    const constant size_t &idx_ax_stride,
    uint3 index,
    uint3 grid_dim)