mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Gather mm new kernel and small refactoring (#2040)
This commit is contained in:
committed by
GitHub
parent
e9e268336b
commit
99eefd2ec0
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -58,14 +60,27 @@ inline void debug_set_primitive_buffer_label(
|
||||
|
||||
std::string get_primitive_string(Primitive* primitive);
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
|
||||
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
|
||||
!std::is_same_v<T, unsigned char> && !std::is_same_v<T, wchar_t>;
|
||||
|
||||
template <typename T>
|
||||
void concatenate(std::string& acc, T first) {
|
||||
acc += first;
|
||||
if constexpr (is_numeric_except_char<T>) {
|
||||
acc += std::to_string(first);
|
||||
} else {
|
||||
acc += first;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void concatenate(std::string& acc, T first, Args... args) {
|
||||
acc += first;
|
||||
if constexpr (is_numeric_except_char<T>) {
|
||||
acc += std::to_string(first);
|
||||
} else {
|
||||
acc += first;
|
||||
}
|
||||
concatenate(acc, args...);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user