Add missing && when forwarding args (#925)

Without the && args would be copied and perfect forwarding won't work.
This commit is contained in:
Cheng 2024-03-29 22:48:29 +09:00 committed by GitHub
parent d8cb3128f6
commit 913b19329c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 32 deletions

View File

@ -24,7 +24,7 @@
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
double time_fn(F fn, Args&&... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, args...);
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, args...);
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, args...);
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, args...);
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, args...);
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, args...);
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, args...);
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, args...);
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, args...);
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, args...);
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, args...);
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, args...);
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, args...);
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, args...);
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args... args) {
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, args...);
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, args...);
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, args...);
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, args...);
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, args...);
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, args...);
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, args...);
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, args...);
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, args...);
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, args...);
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, args...);
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, args...);
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, args...);
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}