MLX
scatter.h File Reference

Go to the source code of this file.

Functions

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(const device T* updates, device mlx_atomic<T>* out, const constant int* out_shape, const constant size_t* out_strides, const constant size_t& out_ndim, const constant int* upd_shape, const constant size_t& upd_ndim, const constant size_t& upd_size, const thread array<const device IdxT*, NIDX>& idx_buffers, uint2 gid)

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(const device T* updates, device mlx_atomic<T>* out, const constant int* upd_shape, const constant size_t* upd_strides, const constant size_t& upd_ndim, const constant size_t& upd_size, const constant int* out_shape, const constant size_t* out_strides, const constant size_t& out_ndim, const constant int* axes, const thread Indices<IdxT, NIDX>& indices, uint2 gid)


Function Documentation

◆ scatter_1d_index_impl()

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* out_shape,
    const constant size_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* upd_shape,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid)

◆ scatter_impl()

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* upd_shape,
    const constant size_t* upd_strides,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const constant int* out_shape,
    const constant size_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* axes,
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid)