// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/kernels/atomic_ops.cuh"

namespace mlx::core::cu {

// Update functors sharing one call shape: `op(out, val)` applies `val` to the
// element that `out` points to. All are device-only and stateless, so they can
// be passed by value and selected at compile time as a template argument.
// NOTE(review): presumably consumed by the scatter kernels as the per-element
// update policy -- confirm against the callers.

// Overwrites `*out` with `val` using a plain store.
// No atomic helper is involved here, unlike the other functors in this file.
// NOTE(review): if several threads can target the same `out`, the surviving
// value is presumably unspecified -- confirm this is acceptable to callers.
struct ScatterAssign {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    *out = val;
  }
};

// Combines `val` into `*out` by delegating to `atomic_add` (declared in
// atomic_ops.cuh).
struct ScatterSum {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_add(out, val);
  }
};

// Combines `val` into `*out` by delegating to `atomic_prod` (declared in
// atomic_ops.cuh).
struct ScatterProd {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_prod(out, val);
  }
};

// Combines `val` into `*out` by delegating to `atomic_max` (declared in
// atomic_ops.cuh).
struct ScatterMax {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_max(out, val);
  }
};

// Combines `val` into `*out` by delegating to `atomic_min` (declared in
// atomic_ops.cuh).
struct ScatterMin {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_min(out, val);
  }
};

} // namespace mlx::core::cu