Go to the source code of this file.
|
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(const device T* updates, device mlx_atomic<T>* out, const constant int* out_shape, const constant size_t* out_strides, const constant size_t& out_ndim, const constant int* upd_shape, const constant size_t& upd_ndim, const constant size_t& upd_size, const thread array<const device IdxT*, NIDX>& idx_buffers, uint2 gid)
|
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(const device T* updates, device mlx_atomic<T>* out, const constant int* upd_shape, const constant size_t* upd_strides, const constant size_t& upd_ndim, const constant size_t& upd_size, const constant int* out_shape, const constant size_t* out_strides, const constant size_t& out_ndim, const constant int* axes, const thread Indices<IdxT, NIDX>& indices, uint2 gid)
|
◆ scatter_1d_index_impl()
template<typename T , typename IdxT , typename Op , int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* out_shape,
    const constant size_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* upd_shape,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid)
◆ scatter_impl()
template<typename T , typename IdxT , typename Op , int NIDX>
METAL_FUNC void scatter_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* upd_shape,
    const constant size_t* upd_strides,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const constant int* out_shape,
    const constant size_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* axes,
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid)