// Copyright © 2023-2024 Apple Inc. #pragma once #include #include #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" struct Add { template T operator()(T x, T y) { return x + y; } }; struct Divide { template T operator()(T x, T y) { return x / y; } }; struct Remainder { template metal::enable_if_t & !metal::is_signed_v, T> operator()(T x, T y) { return x % y; } template metal::enable_if_t & metal::is_signed_v, T> operator()(T x, T y) { auto r = x % y; if (r != 0 && (r < 0 != y < 0)) { r += y; } return r; } template metal::enable_if_t, T> operator()(T x, T y) { T r = fmod(x, y); if (r != 0 && (r < 0 != y < 0)) { r += y; } return r; } template <> complex64_t operator()(complex64_t x, complex64_t y) { return x % y; } }; struct Equal { template bool operator()(T x, T y) { return x == y; } }; struct NaNEqual { template bool operator()(T x, T y) { return x == y || (metal::isnan(x) && metal::isnan(y)); } template <> bool operator()(complex64_t x, complex64_t y) { return x == y || (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && metal::isnan(y.imag)) || (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); } }; struct Greater { template bool operator()(T x, T y) { return x > y; } }; struct GreaterEqual { template bool operator()(T x, T y) { return x >= y; } }; struct Less { template bool operator()(T x, T y) { return x < y; } }; struct LessEqual { template bool operator()(T x, T y) { return x <= y; } }; struct LogAddExp { template T operator()(T x, T y) { if (metal::isnan(x) || metal::isnan(y)) { return metal::numeric_limits::quiet_NaN(); } constexpr T inf = metal::numeric_limits::infinity(); T maxval = metal::max(x, y); T minval = metal::min(x, y); return (minval == -inf || maxval == inf) ? maxval : (maxval + log1p(metal::exp(minval - maxval))); }; }; struct Maximum { template metal::enable_if_t, T> operator()(T x, T y) { return metal::max(x, y); } template metal::enable_if_t, T> operator()(T x, T y) { if (metal::isnan(x)) { return x; } return x > y ? x : y; } template <> complex64_t operator()(complex64_t x, complex64_t y) { if (metal::isnan(x.real) || metal::isnan(x.imag)) { return x; } return x > y ? x : y; } }; struct Minimum { template metal::enable_if_t, T> operator()(T x, T y) { return metal::min(x, y); } template metal::enable_if_t, T> operator()(T x, T y) { if (metal::isnan(x)) { return x; } return x < y ? x : y; } template <> complex64_t operator()(complex64_t x, complex64_t y) { if (metal::isnan(x.real) || metal::isnan(x.imag)) { return x; } return x < y ? x : y; } }; struct Multiply { template T operator()(T x, T y) { return x * y; } }; struct NotEqual { template bool operator()(T x, T y) { return x != y; } template <> bool operator()(complex64_t x, complex64_t y) { return x.real != y.real || x.imag != y.imag; } }; struct Power { template metal::enable_if_t, T> operator()(T base, T exp) { return metal::pow(base, exp); } template metal::enable_if_t, T> operator()(T base, T exp) { T res = 1; while (exp) { if (exp & 1) { res *= base; } exp >>= 1; base *= base; } return res; } template <> complex64_t operator()(complex64_t x, complex64_t y) { auto x_theta = metal::atan(x.imag / x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); auto phase = y.imag * x_ln_r + y.real * x_theta; return {mag * metal::cos(phase), mag * metal::sin(phase)}; } }; struct Subtract { template T operator()(T x, T y) { return x - y; } }; struct LogicalAnd { template T operator()(T x, T y) { return x && y; }; }; struct LogicalOr { template T operator()(T x, T y) { return x || y; }; }; struct BitwiseAnd { template T operator()(T x, T y) { return x & y; }; }; struct BitwiseOr { template T operator()(T x, T y) { return x | y; }; }; struct BitwiseXor { template T operator()(T x, T y) { return x ^ y; }; }; struct LeftShift { template T operator()(T x, T y) { return x << y; }; }; struct RightShift { template T operator()(T x, T y) { return x >> y; }; };