mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
This commit is contained in:
@@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
@@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst, args...);
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst, args...);
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype, args...);
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype, args...);
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Args... args) {
|
||||
Args&&... args) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype, args...);
|
||||
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype, args...);
|
||||
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype, args...);
|
||||
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype, args...);
|
||||
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype, args...);
|
||||
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype, args...);
|
||||
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype, args...);
|
||||
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype, args...);
|
||||
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype, args...);
|
||||
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype, args...);
|
||||
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype, args...);
|
||||
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype, args...);
|
||||
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype, args...);
|
||||
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user