CUDA backend: indexing ops (#2277)

This commit is contained in:
Cheng
2025-06-13 13:44:19 +09:00
committed by GitHub
parent 2188199ff8
commit c8b4787e4e
11 changed files with 830 additions and 6 deletions

View File

@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device/atomic_ops.cuh"
namespace mlx::core::cu {
struct ScatterAssign {
template <typename T>
__device__ void operator()(T* out, T val) const {
*out = val;
}
};
struct ScatterSum {
template <typename T>
__device__ void operator()(T* out, T val) const {
atomic_add(out, val);
}
};
struct ScatterProd {
template <typename T>
__device__ void operator()(T* out, T val) const {
atomic_prod(out, val);
}
};
struct ScatterMax {
template <typename T>
__device__ void operator()(T* out, T val) const {
atomic_max(out, val);
}
};
struct ScatterMin {
template <typename T>
__device__ void operator()(T* out, T val) const {
atomic_min(out, val);
}
};
} // namespace mlx::core::cu