mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 04:06:39 +08:00
96 lines
2.7 KiB
C++
96 lines
2.7 KiB
C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <metal_stdlib>
|
|
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"
|
|
|
|
#pragma METAL internals : enable
|
|
|
|
namespace mlx {
|
|
namespace steel {
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Integral constant with casting
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T, T v>
|
|
struct integral_constant {
|
|
static constexpr constant T value = v;
|
|
using value_type = T;
|
|
using type = integral_constant;
|
|
|
|
METAL_FUNC constexpr operator value_type() const noexcept {
|
|
return value;
|
|
}
|
|
|
|
// METAL_FUNC constexpr value_type operator()() const noexcept {
|
|
// return value;
|
|
// }
|
|
};
|
|
|
|
template <bool B>
|
|
using bool_constant = integral_constant<bool, B>;
|
|
using true_type = bool_constant<true>;
|
|
using false_type = bool_constant<false>;
|
|
|
|
template <class T>
|
|
struct is_integral : bool_constant<metal::is_integral<T>::value> {};
|
|
|
|
template <class T, T v>
|
|
struct is_integral<integral_constant<T, v>>
|
|
: bool_constant<metal::is_integral<T>::value> {};
|
|
|
|
template <typename T>
|
|
constexpr constant bool is_integral_v = is_integral<T>::value;
|
|
|
|
template <int val>
|
|
using Int = integral_constant<int, val>;
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Binary Operators on Integral constants
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
#define integral_const_binop(__op__, __operator__) \
|
|
template <typename T, T tv, typename U, U uv> \
|
|
METAL_FUNC constexpr auto __operator__( \
|
|
integral_constant<T, tv>, integral_constant<U, uv>) { \
|
|
constexpr auto res = tv __op__ uv; \
|
|
return integral_constant<decltype(res), res>{}; \
|
|
}
|
|
|
|
integral_const_binop(+, operator+);
|
|
integral_const_binop(-, operator-);
|
|
integral_const_binop(*, operator*);
|
|
integral_const_binop(/, operator/);
|
|
|
|
integral_const_binop(==, operator==);
|
|
integral_const_binop(!=, operator!=);
|
|
integral_const_binop(<, operator<);
|
|
integral_const_binop(>, operator>);
|
|
integral_const_binop(<=, operator<=);
|
|
integral_const_binop(>=, operator>=);
|
|
|
|
integral_const_binop(&&, operator&&);
|
|
integral_const_binop(||, operator||);
|
|
|
|
#undef integral_const_binop
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Reduction operators
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename T>
|
|
METAL_FUNC constexpr T sum(T x) {
|
|
return x;
|
|
}
|
|
|
|
template <typename T, typename... Us>
|
|
METAL_FUNC constexpr auto sum(T x, Us... us) {
|
|
return x + sum(us...);
|
|
}
|
|
|
|
} // namespace steel
|
|
} // namespace mlx
|
|
|
|
#pragma METAL internals : disable |