Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
LogCumSumExp (#2069)
This commit is contained in:
parent 7275ac7523
commit e9e268336b
@@ -38,6 +38,7 @@ Array
   array.log10
   array.log1p
   array.log2
+  array.logcumsumexp
   array.logsumexp
   array.max
   array.mean
@@ -103,6 +103,7 @@ Operations
   log10
   log1p
   logaddexp
+  logcumsumexp
   logical_not
   logical_and
   logical_or
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
       scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
+    case Scan::LogAddExp: {
+      auto op = [](U a, T b) {
+        return detail::LogAddExp{}(a, static_cast<U>(b));
+      };
+      auto init = (issubdtype(in.dtype(), floating))
+          ? static_cast<U>(-std::numeric_limits<float>::infinity())
+          : std::numeric_limits<U>::min();
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
+      break;
+    }
   }
 }
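For intuition, the LogAddExp case above performs a running log-sum-exp: each element is logaddexp-accumulated onto the previous result, starting from -inf (the identity for logaddexp on floating types). A minimal NumPy sketch of the same recurrence, for illustration only and not part of the diff:

import numpy as np

def logcumsumexp_1d(x):
    # Running logaddexp over a 1-D array, starting from the identity -inf:
    # out[i] = log(exp(x[0]) + ... + exp(x[i])).
    out = np.empty_like(x)
    acc = -np.inf
    for i, v in enumerate(x):
        acc = np.logaddexp(acc, v)
        out[i] = acc
    return out

x = np.random.randn(8).astype(np.float32)
assert np.allclose(logcumsumexp_1d(x), np.logaddexp.accumulate(x))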
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include "mlx/backend/metal/kernels/binary_ops.h"
+
 #define DEFINE_SIMD_SCAN() \
   template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
   T simd_scan(T val) { \
@@ -139,6 +141,29 @@ struct CumMin {
   }
 };
 
+template <typename U>
+struct CumLogaddexp {
+  static constexpr constant U init = Limits<U>::min;
+
+  template <typename T>
+  U operator()(U a, T b) {
+    return LogAddExp{}(a, static_cast<U>(b));
+  }
+
+  U simd_scan(U x) {
+    for (int i = 1; i <= 16; i *= 2) {
+      U other = simd_shuffle_and_fill_up(x, init, i);
+      x = LogAddExp{}(x, other);
+    }
+    return x;
+  }
+
+  U simd_exclusive_scan(U x) {
+    x = simd_scan(x);
+    return simd_shuffle_and_fill_up(x, init, 1);
+  }
+};
+
 template <typename T, typename U, int N_READS, bool reverse>
 inline void load_unsafe(U values[N_READS], const device T* input) {
   if (reverse) {
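The simd_scan above is a shuffle-and-combine (Hillis–Steele style) inclusive scan across a SIMD group: at each step a lane combines its value with the value a power-of-two number of lanes below it, with out-of-range lanes filled with the identity. A rough NumPy illustration of the same pattern over an assumed 32-lane group, for illustration only:

import numpy as np

def simd_logaddexp_scan(lanes, width=32, identity=-np.inf):
    # Shuffle-and-combine inclusive scan: after log2(width) steps every
    # lane holds the logaddexp of all lanes up to and including itself.
    x = np.asarray(lanes, dtype=np.float32).copy()
    offset = 1
    while offset < width:
        shifted = np.concatenate([np.full(offset, identity, x.dtype), x[:-offset]])
        x = np.logaddexp(x, shifted)
        offset *= 2
    return x

vals = np.random.randn(32).astype(np.float32)
assert np.allclose(simd_logaddexp_scan(vals), np.logaddexp.accumulate(vals), atol=1e-5)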
@@ -101,4 +101,7 @@ instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMi
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
 instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
-instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
+instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2)
+instantiate_scan_helper(logaddexp_float16_float16, half, half, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_float32_float32, float, float, CumLogaddexp, 4)
+instantiate_scan_helper(logaddexp_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumLogaddexp, 4) // clang-format on
@@ -60,6 +60,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     case Scan::Min:
       reduce_type = "min";
      break;
+    case Scan::LogAddExp:
+      reduce_type = "logaddexp";
+      break;
   }
   kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
   auto kernel = get_scan_kernel(
@@ -298,7 +298,13 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(Reshape),
       SERIALIZE_PRIMITIVE(Reduce, "And", "Or", "Sum", "Prod", "Min", "Max"),
       SERIALIZE_PRIMITIVE(Round),
-      SERIALIZE_PRIMITIVE(Scan, "CumSum", "CumProd", "CumMin", "CumMax"),
+      SERIALIZE_PRIMITIVE(
+          Scan,
+          "CumSum",
+          "CumProd",
+          "CumMin",
+          "CumMax",
+          "CumLogaddexp"),
       SERIALIZE_PRIMITIVE(Scatter),
       SERIALIZE_PRIMITIVE(Select),
       SERIALIZE_PRIMITIVE(Sigmoid),
mlx/ops.cpp
@@ -3504,6 +3504,28 @@ array cummin(
       {a});
 }
 
+array logcumsumexp(
+    const array& a,
+    int axis,
+    bool reverse /* = false*/,
+    bool inclusive /* = true*/,
+    StreamOrDevice s /* = {}*/) {
+  int ndim = a.ndim();
+  if (axis >= ndim || axis < -ndim) {
+    std::ostringstream msg;
+    msg << "[logcumsumexp] Axis " << axis << " is out of bounds for array with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  axis = (axis + a.ndim()) % a.ndim();
+  return array(
+      a.shape(),
+      a.dtype(),
+      std::make_shared<Scan>(
+          to_stream(s), Scan::ReduceType::LogAddExp, axis, reverse, inclusive),
+      {a});
+}
+
 /** Convolution operations */
 
 namespace {
@@ -715,6 +715,14 @@ array topk(const array& a, int k, StreamOrDevice s = {});
 /** Returns topk elements of the array along a given axis. */
 array topk(const array& a, int k, int axis, StreamOrDevice s = {});
 
+/** Cumulative logsumexp of an array. */
+array logcumsumexp(
+    const array& a,
+    int axis,
+    bool reverse = false,
+    bool inclusive = true,
+    StreamOrDevice s = {});
+
 /** The logsumexp of all elements of the array. */
 array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
 inline array logsumexp(const array& a, StreamOrDevice s = {}) {
@@ -3478,6 +3478,45 @@ std::vector<array> Scan::vjp(
 
   if (reduce_type_ == Scan::Sum) {
     return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
+  } else if (reduce_type_ == Scan::LogAddExp) {
+    // Ref:
+    // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
+
+    auto x = primals[0];
+    auto grad = cotangents[0];
+    auto results = outputs[0];
+
+    auto zero = zeros({1}, grad.dtype(), stream());
+    auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
+
+    // Split the incoming gradient into positive and negative part
+    // in order to take logs. This is required for stable results.
+    auto log_abs_grad = log(abs(grad, stream()), stream());
+    auto log_grad_positive =
+        where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
+    auto log_grad_negative =
+        where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());
+
+    auto output_pos = exp(
+        add(logcumsumexp(
+                subtract(log_grad_positive, results, stream()),
+                axis_,
+                !reverse_,
+                inclusive_,
+                stream()),
+            x,
+            stream()));
+    auto output_neg = exp(
+        add(logcumsumexp(
+                subtract(log_grad_negative, results, stream()),
+                axis_,
+                !reverse_,
+                inclusive_,
+                stream()),
+            x,
+            stream()));
+
+    return {subtract(output_pos, output_neg, stream())};
   } else if (reduce_type_ == Scan::Prod) {
     auto in = primals[0];
     // Find the location of the first 0 and set it to 1:
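For intuition, the VJP branch above implements the gradient identity for an inclusive forward scan y = logcumsumexp(x): dL/dx = exp(x) * reverse_cumsum(dL/dy * exp(-y)); the positive/negative split exists only so the cumulative sum can be carried out in log space. A naive NumPy sketch of the same identity without that stability split, for illustration only:

import numpy as np

def naive_logcumsumexp_vjp(x, y, grad, axis=0):
    # y = logcumsumexp(x) (inclusive, forward scan); grad = dL/dy.
    # dL/dx_i = sum_{j >= i} grad_j * exp(x_i - y_j)
    #         = exp(x_i) * reverse_cumsum(grad * exp(-y))_i
    rev_cumsum = lambda t: np.flip(np.cumsum(np.flip(t, axis), axis=axis), axis)
    return np.exp(x) * rev_cumsum(grad * np.exp(-y))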
@@ -1728,7 +1728,7 @@ class Round : public UnaryPrimitive {
 
 class Scan : public UnaryPrimitive {
  public:
-  enum ReduceType { Max, Min, Sum, Prod };
+  enum ReduceType { Max, Min, Sum, Prod, LogAddExp };
 
   explicit Scan(
       Stream stream,
@@ -1763,6 +1763,9 @@ class Scan : public UnaryPrimitive {
       case Max:
        os << "Max";
        break;
+      case LogAddExp:
+        os << "Logaddexp";
+        break;
     }
   }
   bool is_equivalent(const Primitive& other) const override;
@@ -1202,6 +1202,28 @@ void init_array(nb::module_& m) {
           nb::kw_only(),
           "stream"_a = nb::none(),
           "See :func:`max`.")
+      .def(
+          "logcumsumexp",
+          [](const mx::array& a,
+             std::optional<int> axis,
+             bool reverse,
+             bool inclusive,
+             mx::StreamOrDevice s) {
+            if (axis) {
+              return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
+            } else {
+              // TODO: Implement that in the C++ API as well. See concatenate
+              // above.
+              return mx::logcumsumexp(
+                  mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
+            }
+          },
+          "axis"_a = nb::none(),
+          nb::kw_only(),
+          "reverse"_a = false,
+          "inclusive"_a = true,
+          "stream"_a = nb::none(),
+          "See :func:`logcumsumexp`.")
       .def(
           "logsumexp",
           [](const mx::array& a,
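With the binding above, logcumsumexp is also exposed as an array method; a brief hypothetical usage sketch (values illustrative only):

import mlx.core as mx

a = mx.array([[0.0, 1.0], [2.0, 3.0]])
print(a.logcumsumexp(axis=1))  # running log-sum-exp along each row
print(a.logcumsumexp())        # axis omitted: scan over the flattened array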
@@ -2382,6 +2382,43 @@ void init_ops(nb::module_& m) {
       Returns:
           array: The output array with the corresponding axes reduced.
       )pbdoc");
+  m.def(
+      "logcumsumexp",
+      [](const mx::array& a,
+         std::optional<int> axis,
+         bool reverse,
+         bool inclusive,
+         mx::StreamOrDevice s) {
+        if (axis) {
+          return mx::logcumsumexp(a, *axis, reverse, inclusive, s);
+        } else {
+          return mx::logcumsumexp(
+              mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
+        }
+      },
+      nb::arg(),
+      "axis"_a = nb::none(),
+      nb::kw_only(),
+      "reverse"_a = false,
+      "inclusive"_a = true,
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Return the cumulative logsumexp of the elements along the given axis.
+
+        Args:
+            a (array): Input array
+            axis (int, optional): Optional axis to compute the cumulative logsumexp
+              over. If unspecified the cumulative logsumexp of the flattened array is
+              returned.
+            reverse (bool): Perform the cumulative logsumexp in reverse.
+            inclusive (bool): The i-th element of the output includes the i-th
+              element of the input.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
   m.def(
       "logsumexp",
       [](const mx::array& a,
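A brief usage sketch of the Python op added above (values illustrative); the second expression is a naive reference that is mathematically equivalent but can overflow for large inputs, which is what the log-space scan avoids:

import mlx.core as mx

x = mx.array([0.0, 1.0, 2.0, 3.0])
y = mx.logcumsumexp(x, axis=0)          # stable running log-sum-exp
y_naive = mx.log(mx.cumsum(mx.exp(x)))  # same values for small x, unstable otherwise
print(y, y_naive)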
@@ -1508,6 +1508,7 @@ class TestArray(mlx_tests.MLXTestCase):
             ("prod", 1),
             ("min", 1),
             ("max", 1),
+            ("logcumsumexp", 1),
             ("logsumexp", 1),
             ("mean", 1),
             ("var", 1),
@@ -1857,6 +1857,30 @@ class TestOps(mlx_tests.MLXTestCase):
         y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
         self.assertTrue(mx.array_equal(y, x[::-1]))
 
+    def test_logcumsumexp(self):
+        npop = np.logaddexp.accumulate
+        mxop = mx.logcumsumexp
+
+        a_npy = np.random.randn(32, 32, 32).astype(np.float32)
+        a_mlx = mx.array(a_npy)
+
+        for axis in (0, 1, 2):
+            c_npy = npop(a_npy, axis=axis)
+            c_mlx = mxop(a_mlx, axis=axis)
+            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
+
+        edge_cases_npy = [
+            np.float32([-float("inf")] * 8),
+            np.float32([-float("inf"), 0, -float("inf")]),
+            np.float32([-float("inf"), float("inf"), -float("inf")]),
+        ]
+        edge_cases_mlx = [mx.array(a) for a in edge_cases_npy]
+
+        for a_npy, a_mlx in zip(edge_cases_npy, edge_cases_mlx):
+            c_npy = npop(a_npy, axis=0)
+            c_mlx = mxop(a_mlx, axis=0)
+            self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
+
     def test_scans(self):
         a_npy = np.random.randn(32, 32, 32).astype(np.float32)
         a_mlx = mx.array(a_npy)