Cuda perf tuning (#2307)

* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
This commit is contained in:
Awni Hannun
2025-06-20 14:50:57 -07:00
committed by GitHub
parent 76831ed83d
commit c9a9180584
5 changed files with 63 additions and 24 deletions

View File

@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) {
axis += in.ndim();
}
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.flags());
}
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {