mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
This commit is contained in:
parent
d8cb3128f6
commit
913b19329c
@ -24,7 +24,7 @@
|
|||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
template <typename F, typename... Args>
|
template <typename F, typename... Args>
|
||||||
double time_fn(F fn, Args... args) {
|
double time_fn(F fn, Args&&... args) {
|
||||||
// warmup
|
// warmup
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
eval(fn(std::forward<Args>(args)...));
|
eval(fn(std::forward<Args>(args)...));
|
||||||
|
@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename DstT, typename... Args>
|
template <typename SrcT, typename DstT, typename... Args>
|
||||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||||
switch (ctype) {
|
switch (ctype) {
|
||||||
case CopyType::Scalar:
|
case CopyType::Scalar:
|
||||||
copy_single<SrcT, DstT>(src, dst);
|
copy_single<SrcT, DstT>(src, dst);
|
||||||
@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
|||||||
copy_vector<SrcT, DstT>(src, dst);
|
copy_vector<SrcT, DstT>(src, dst);
|
||||||
return;
|
return;
|
||||||
case CopyType::General:
|
case CopyType::General:
|
||||||
copy_general<SrcT, DstT>(src, dst, args...);
|
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||||
return;
|
return;
|
||||||
case CopyType::GeneralGeneral:
|
case CopyType::GeneralGeneral:
|
||||||
copy_general_general<SrcT, DstT>(src, dst, args...);
|
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename... Args>
|
template <typename SrcT, typename... Args>
|
||||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||||
switch (dst.dtype()) {
|
switch (dst.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
copy<SrcT, bool>(src, dst, ctype, args...);
|
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint8:
|
case uint8:
|
||||||
copy<SrcT, uint8_t>(src, dst, ctype, args...);
|
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint16:
|
case uint16:
|
||||||
copy<SrcT, uint16_t>(src, dst, ctype, args...);
|
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint32:
|
case uint32:
|
||||||
copy<SrcT, uint32_t>(src, dst, ctype, args...);
|
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint64:
|
case uint64:
|
||||||
copy<SrcT, uint64_t>(src, dst, ctype, args...);
|
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
copy<SrcT, int8_t>(src, dst, ctype, args...);
|
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
copy<SrcT, int16_t>(src, dst, ctype, args...);
|
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
copy<SrcT, int32_t>(src, dst, ctype, args...);
|
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
copy<SrcT, int64_t>(src, dst, ctype, args...);
|
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
copy<SrcT, float16_t>(src, dst, ctype, args...);
|
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case float32:
|
case float32:
|
||||||
copy<SrcT, float>(src, dst, ctype, args...);
|
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
|
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case complex64:
|
case complex64:
|
||||||
copy<SrcT, complex64_t>(src, dst, ctype, args...);
|
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
|
|||||||
const array& src,
|
const array& src,
|
||||||
array& dst,
|
array& dst,
|
||||||
CopyType ctype,
|
CopyType ctype,
|
||||||
Args... args) {
|
Args&&... args) {
|
||||||
switch (src.dtype()) {
|
switch (src.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
copy<bool>(src, dst, ctype, args...);
|
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint8:
|
case uint8:
|
||||||
copy<uint8_t>(src, dst, ctype, args...);
|
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint16:
|
case uint16:
|
||||||
copy<uint16_t>(src, dst, ctype, args...);
|
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint32:
|
case uint32:
|
||||||
copy<uint32_t>(src, dst, ctype, args...);
|
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case uint64:
|
case uint64:
|
||||||
copy<uint64_t>(src, dst, ctype, args...);
|
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
copy<int8_t>(src, dst, ctype, args...);
|
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
copy<int16_t>(src, dst, ctype, args...);
|
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
copy<int32_t>(src, dst, ctype, args...);
|
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
copy<int64_t>(src, dst, ctype, args...);
|
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
copy<float16_t>(src, dst, ctype, args...);
|
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case float32:
|
case float32:
|
||||||
copy<float>(src, dst, ctype, args...);
|
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
copy<bfloat16_t>(src, dst, ctype, args...);
|
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
case complex64:
|
case complex64:
|
||||||
copy<complex64_t>(src, dst, ctype, args...);
|
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user