mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h
2024-10-30 19:30:28 -07:00

96 lines
2.7 KiB
C++

// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"
#pragma METAL internals : enable
namespace mlx {
namespace steel {
///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////
template <typename T, T v>
struct integral_constant {
static constexpr constant T value = v;
using value_type = T;
using type = integral_constant;
METAL_FUNC constexpr operator value_type() const noexcept {
return value;
}
// METAL_FUNC constexpr value_type operator()() const noexcept {
// return value;
// }
};
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};
template <class T, T v>
struct is_integral<integral_constant<T, v>>
: bool_constant<metal::is_integral<T>::value> {};
template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;
template <int val>
using Int = integral_constant<int, val>;
///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////
#define integral_const_binop(__op__, __operator__) \
template <typename T, T tv, typename U, U uv> \
METAL_FUNC constexpr auto __operator__( \
integral_constant<T, tv>, integral_constant<U, uv>) { \
constexpr auto res = tv __op__ uv; \
return integral_constant<decltype(res), res>{}; \
}
integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);
integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);
integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);
#undef integral_const_binop
///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC constexpr T sum(T x) {
return x;
}
template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
return x + sum(us...);
}
} // namespace steel
} // namespace mlx
#pragma METAL internals : disable