mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Cuda perf tuning (#2307)
* perf tuning * fix adding inputs arrays in matmul / srot * format * fix
This commit is contained in:
@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
array out = out_;
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.flags());
|
||||
}
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||
|
||||
Reference in New Issue
Block a user