Add missing && when forwarding args (#925)

Without the && args would be copied and perfect forwarding won't work.
This commit is contained in:
Cheng 2024-03-29 22:48:29 +09:00 committed by GitHub
parent d8cb3128f6
commit 913b19329c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 32 deletions

View File

@ -24,7 +24,7 @@
<< std::endl; << std::endl;
template <typename F, typename... Args> template <typename F, typename... Args>
double time_fn(F fn, Args... args) { double time_fn(F fn, Args&&... args) {
// warmup // warmup
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...)); eval(fn(std::forward<Args>(args)...));

View File

@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
} }
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) { void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) { switch (ctype) {
case CopyType::Scalar: case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst); copy_single<SrcT, DstT>(src, dst);
@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
copy_vector<SrcT, DstT>(src, dst); copy_vector<SrcT, DstT>(src, dst);
return; return;
case CopyType::General: case CopyType::General:
copy_general<SrcT, DstT>(src, dst, args...); copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, args...); copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
} }
} }
template <typename SrcT, typename... Args> template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) { void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (dst.dtype()) { switch (dst.dtype()) {
case bool_: case bool_:
copy<SrcT, bool>(src, dst, ctype, args...); copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, args...); copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, args...); copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, args...); copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, args...); copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<SrcT, int8_t>(src, dst, ctype, args...); copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<SrcT, int16_t>(src, dst, ctype, args...); copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<SrcT, int32_t>(src, dst, ctype, args...); copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<SrcT, int64_t>(src, dst, ctype, args...); copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<SrcT, float16_t>(src, dst, ctype, args...); copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<SrcT, float>(src, dst, ctype, args...); copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, args...); copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, args...); copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
} }
} }
@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
const array& src, const array& src,
array& dst, array& dst,
CopyType ctype, CopyType ctype,
Args... args) { Args&&... args) {
switch (src.dtype()) { switch (src.dtype()) {
case bool_: case bool_:
copy<bool>(src, dst, ctype, args...); copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint8: case uint8:
copy<uint8_t>(src, dst, ctype, args...); copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint16: case uint16:
copy<uint16_t>(src, dst, ctype, args...); copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint32: case uint32:
copy<uint32_t>(src, dst, ctype, args...); copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case uint64: case uint64:
copy<uint64_t>(src, dst, ctype, args...); copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int8: case int8:
copy<int8_t>(src, dst, ctype, args...); copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int16: case int16:
copy<int16_t>(src, dst, ctype, args...); copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int32: case int32:
copy<int32_t>(src, dst, ctype, args...); copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case int64: case int64:
copy<int64_t>(src, dst, ctype, args...); copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float16: case float16:
copy<float16_t>(src, dst, ctype, args...); copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float32: case float32:
copy<float>(src, dst, ctype, args...); copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case bfloat16: case bfloat16:
copy<bfloat16_t>(src, dst, ctype, args...); copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case complex64: case complex64:
copy<complex64_t>(src, dst, ctype, args...); copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
} }
} }