MLX
 
Loading...
Searching...
No Matches
scatter_axis.h File Reference

Go to the source code of this file.

Functions

template<typename T, typename IdxT, typename LocT, typename Op, bool UpdC, bool IdxC>
void scatter_axis (const device T *upd, const device IdxT *indices, device mlx_atomic< T > *out, const constant int *shape, const constant int64_t *upd_strides, const constant int64_t *idx_strides, const constant size_t &ndim, const constant int &axis, const constant int &out_axis_size, const constant size_t &upd_ax_stride, const constant size_t &idx_ax_stride, uint3 index, uint3 grid_dim)
 

Function Documentation

◆ scatter_axis()

template<typename T, typename IdxT, typename LocT, typename Op, bool UpdC, bool IdxC>
void scatter_axis(const device T *upd,
                  const device IdxT *indices,
                  device mlx_atomic<T> *out,
                  const constant int *shape,
                  const constant int64_t *upd_strides,
                  const constant int64_t *idx_strides,
                  const constant size_t &ndim,
                  const constant int &axis,
                  const constant int &out_axis_size,
                  const constant size_t &upd_ax_stride,
                  const constant size_t &idx_ax_stride,
                  uint3 index,
                  uint3 grid_dim)