scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
This commit is contained in:
Awni Hannun
2025-01-31 20:48:08 -08:00
committed by GitHub
parent c6fc07f1f4
commit b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
make_jit_source(hadamard)
if(MLX_METAL_JIT)