mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
This commit is contained in:
committed by
GitHub
parent
5fd11c347d
commit
28eac18571
@@ -1,176 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T> T operator()(T x, T y) { return x + y; }
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T> T operator()(T x, T y) { return x % y; }
|
||||
template <> float operator()(float x, float y) { return fmod(x, y); }
|
||||
template <> half operator()(half x, half y) { return fmod(x, y); }
|
||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T> bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real)
|
||||
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T> bool operator()(T x, T y) { return x > y; }
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x >= y; }
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T> bool operator()(T x, T y) { return x < y; }
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x <= y; }
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
if (metal::isnan(x) || metal::isnan(y)) {
|
||||
return metal::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf) ? maxval :
|
||||
(maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::min(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T> T operator()(T x, T y) { return x * y; }
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x != y; }
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x && y; };
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x || y; };
|
||||
};
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
|
||||
Reference in New Issue
Block a user