From 913b19329c6620a22c5f539d2d55f89ec40da037 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 29 Mar 2024 22:48:29 +0900 Subject: [PATCH] Add missing && when forwarding args (#925) Without the && args would be copied and perfect forwarding won't work. --- benchmarks/cpp/time_utils.h | 2 +- mlx/backend/common/copy.cpp | 62 ++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/benchmarks/cpp/time_utils.h b/benchmarks/cpp/time_utils.h index 780f1867f..09ba6c173 100644 --- a/benchmarks/cpp/time_utils.h +++ b/benchmarks/cpp/time_utils.h @@ -24,7 +24,7 @@ << std::endl; template -double time_fn(F fn, Args... args) { +double time_fn(F fn, Args&&... args) { // warmup for (int i = 0; i < 5; ++i) { eval(fn(std::forward(args)...)); diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index 53956041a..2272ff325 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) { } template -void copy(const array& src, array& dst, CopyType ctype, Args... args) { +void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (ctype) { case CopyType::Scalar: copy_single(src, dst); @@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) { copy_vector(src, dst); return; case CopyType::General: - copy_general(src, dst, args...); + copy_general(src, dst, std::forward(args)...); return; case CopyType::GeneralGeneral: - copy_general_general(src, dst, args...); + copy_general_general(src, dst, std::forward(args)...); } } template -void copy(const array& src, array& dst, CopyType ctype, Args... args) { +void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (dst.dtype()) { case bool_: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint8: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint32: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint64: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int8: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int32: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int64: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case float16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case float32: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case complex64: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; } } @@ -338,46 +338,46 @@ inline void copy_inplace_dispatch( const array& src, array& dst, CopyType ctype, - Args... args) { + Args&&... args) { switch (src.dtype()) { case bool_: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint8: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint32: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case uint64: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int8: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int32: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case int64: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case float16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case float32: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case bfloat16: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; case complex64: - copy(src, dst, ctype, args...); + copy(src, dst, ctype, std::forward(args)...); break; } }