diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index cf957bd02..4ed8d0607 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -1,6 +1,5 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/utils.h"
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index f7b7e60b7..5b86961da 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -362,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());
 
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -448,10 +457,21 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());
 
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>(), c.data<int8_t>(), alpha_, beta_);
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index 154ca5f32..5cbffc0f4 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
-
   if (axis < 0) {
     axis += in.ndim();
   }
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         in.flags());
   }
 
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       if constexpr (!std::is_same_v<CTYPE, complex64_t>) {