mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Reduce vmap + some fixes (#601)
This commit is contained in:
parent
601c6d6aa8
commit
e88e474fd1
77
mlx/ops.cpp
77
mlx/ops.cpp
@ -17,8 +17,7 @@ namespace {
|
|||||||
|
|
||||||
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const std::vector<int>& shape,
|
const std::vector<int>& shape) {
|
||||||
bool keepdims) {
|
|
||||||
std::set<int> axes_set;
|
std::set<int> axes_set;
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
for (auto ax : axes) {
|
for (auto ax : axes) {
|
||||||
@ -38,7 +37,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
|||||||
for (int i = 0; i < ndim; ++i) {
|
for (int i = 0; i < ndim; ++i) {
|
||||||
if (axes_set.count(i) == 0) {
|
if (axes_set.count(i) == 0) {
|
||||||
out_shape.push_back(shape[i]);
|
out_shape.push_back(shape[i]);
|
||||||
} else if (keepdims) {
|
} else {
|
||||||
out_shape.push_back(1);
|
out_shape.push_back(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1217,13 +1216,16 @@ array all(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return astype(a, bool_, s);
|
return astype(a, bool_, s);
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||||
compute_reduce_shape(axes, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array all(
|
array all(
|
||||||
@ -1248,13 +1250,16 @@ array any(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return astype(a, bool_, s);
|
return astype(a, bool_, s);
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||||
compute_reduce_shape(axes, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array any(
|
array any(
|
||||||
@ -1279,14 +1284,17 @@ array sum(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||||
compute_reduce_shape(axes, a.shape(), keepdims);
|
|
||||||
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
|
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
|
||||||
return array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array sum(
|
array sum(
|
||||||
@ -1374,13 +1382,16 @@ array prod(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||||
compute_reduce_shape(axes, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array prod(
|
array prod(
|
||||||
@ -1408,13 +1419,16 @@ array max(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||||
compute_reduce_shape(axes, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array max(
|
array max(
|
||||||
@ -1442,13 +1456,16 @@ array min(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||||
compute_reduce_shape(axes, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array min(
|
array min(
|
||||||
@ -1477,14 +1494,17 @@ array argmin(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[argmin] Cannot argmin reduce zero size array.");
|
"[argmin] Cannot argmin reduce zero size array.");
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||||
compute_reduce_shape({axis}, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
uint32,
|
uint32,
|
||||||
std::make_unique<ArgReduce>(
|
std::make_unique<ArgReduce>(
|
||||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
||||||
@ -1505,14 +1525,17 @@ array argmax(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[argmax] Cannot argmax reduce zero size array.");
|
"[argmax] Cannot argmax reduce zero size array.");
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes] =
|
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||||
compute_reduce_shape({axis}, a.shape(), keepdims);
|
auto out = array(
|
||||||
return array(
|
|
||||||
out_shape,
|
out_shape,
|
||||||
uint32,
|
uint32,
|
||||||
std::make_unique<ArgReduce>(
|
std::make_unique<ArgReduce>(
|
||||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||||
{a});
|
{a});
|
||||||
|
if (!keepdims) {
|
||||||
|
out = squeeze(out, sorted_axes, s);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns a sorted copy of the flattened array. */
|
/** Returns a sorted copy of the flattened array. */
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
@ -361,6 +360,20 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
|
|||||||
return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;
|
return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
int reduce_ax = axis_ + (axis_ >= axes[0]);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
std::vector<array> out;
|
||||||
|
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||||
|
out.push_back(argmin(in, reduce_ax, true, stream()));
|
||||||
|
} else {
|
||||||
|
out.push_back(argmax(in, reduce_ax, true, stream()));
|
||||||
|
}
|
||||||
|
return {out, axes};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
@ -2153,7 +2166,36 @@ std::vector<array> Reduce::vjp(
|
|||||||
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
throw std::runtime_error("Reduce::vmap not yet implemented.");
|
auto ax = axes[0];
|
||||||
|
auto reduce_axes = axes_;
|
||||||
|
for (auto& rax : reduce_axes) {
|
||||||
|
if (rax >= ax) {
|
||||||
|
rax++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto& in = inputs[0];
|
||||||
|
std::vector<array> out;
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Reduce::And:
|
||||||
|
out.push_back(all(in, reduce_axes, true, stream()));
|
||||||
|
break;
|
||||||
|
case Reduce::Or:
|
||||||
|
out.push_back(any(in, reduce_axes, true, stream()));
|
||||||
|
break;
|
||||||
|
case Reduce::Sum:
|
||||||
|
out.push_back(sum(in, reduce_axes, true, stream()));
|
||||||
|
break;
|
||||||
|
case Reduce::Prod:
|
||||||
|
out.push_back(prod(in, reduce_axes, true, stream()));
|
||||||
|
break;
|
||||||
|
case Reduce::Min:
|
||||||
|
out.push_back(min(in, reduce_axes, true, stream()));
|
||||||
|
break;
|
||||||
|
case Reduce::Max:
|
||||||
|
out.push_back(max(in, reduce_axes, true, stream()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return {out, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Reduce::is_equivalent(const Primitive& other) const {
|
bool Reduce::is_equivalent(const Primitive& other) const {
|
||||||
|
@ -341,6 +341,7 @@ class ArgReduce : public UnaryPrimitive {
|
|||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(ArgReduce)
|
DEFINE_PRINT(ArgReduce)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
@ -548,9 +548,8 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
|||||||
"[vmap] The number of in axes must match the number of inputs.");
|
"[vmap] The number of in axes must match the number of inputs.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the function on placeholder inputs
|
// Some error checking and get the vmap axis size
|
||||||
// to get the original graph
|
size_t vmap_ax_size;
|
||||||
std::vector<array> s_inputs;
|
|
||||||
for (int i = 0; i < inputs.size(); ++i) {
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
if (in_axes[i] != -1) {
|
if (in_axes[i] != -1) {
|
||||||
if (inputs[i].ndim() == 0) {
|
if (inputs[i].ndim() == 0) {
|
||||||
@ -563,7 +562,26 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
|||||||
<< inputs[i].ndim() << " dimensions.";
|
<< inputs[i].ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
vmap_ax_size = inputs[i].shape(in_axes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check that all vmapped axes have the same size
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (in_axes[i] != -1) {
|
||||||
|
if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and "
|
||||||
|
<< vmap_ax_size << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the function on placeholder inputs
|
||||||
|
// to get the original graph
|
||||||
|
std::vector<array> s_inputs;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (in_axes[i] != -1) {
|
||||||
std::vector<int> shape = inputs[i].shape();
|
std::vector<int> shape = inputs[i].shape();
|
||||||
shape.erase(shape.begin() + in_axes[i]);
|
shape.erase(shape.begin() + in_axes[i]);
|
||||||
array in(shape, inputs[i].dtype(), nullptr, {});
|
array in(shape, inputs[i].dtype(), nullptr, {});
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -220,6 +220,50 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(mx.array_equal(out, expected))
|
self.assertTrue(mx.array_equal(out, expected))
|
||||||
|
|
||||||
|
def test_vmap_reduce(self):
    # Sum reductions under vmap on a 5x5 array of ones (default in_axes).
    square = mx.ones((5, 5), mx.int32)
    square_cases = [
        (lambda x: x.sum(), (5,)),
        (lambda x: x.sum(keepdims=True), (5, 1)),
        (lambda x: x.sum(axis=0), (5,)),
    ]
    for reduce_fn, expected_shape in square_cases:
        result = mx.vmap(reduce_fn)(square)
        self.assertTrue(mx.array_equal(result, mx.full(expected_shape, 5)))

    # Reduce the two remaining axes of a (5, 3, 2) array of ones, first with
    # the default in_axes and then mapping over axis 1 and axis 2 explicitly.
    cube = mx.ones((5, 3, 2), mx.int32)
    reduce_rest = lambda x: x.sum(axis=(0, 1))

    result = mx.vmap(reduce_rest)(cube)
    self.assertTrue(mx.array_equal(result, mx.full((5,), 6)))

    for mapped_axis, mapped_size, total in ((1, 3, 10), (2, 2, 15)):
        result = mx.vmap(reduce_rest, in_axes=(mapped_axis,))(cube)
        self.assertTrue(mx.array_equal(result, mx.full((mapped_size,), total)))
|
||||||
|
|
||||||
|
def test_vmap_argreduce(self):
    # argmin / argmax applied per row through vmap.
    rows = mx.array([[1, 2, 3], [2, 3, 1]])
    cases = (
        (mx.argmin, [0, 2]),
        (mx.argmax, [2, 1]),
    )
    for arg_op, wanted in cases:
        got = mx.vmap(lambda x: arg_op(x))(rows)
        self.assertTrue(mx.array_equal(got, mx.array(wanted)))
|
||||||
|
|
||||||
|
def test_mismatch_input_sizes(self):
    # vmap is expected to raise ValueError when the mapped axes of its
    # inputs do not have the same size.
    add = lambda x, y: x + y
    lhs = mx.ones((10, 1))

    # Default in_axes: sizes 10 and 1 along the mapped axis.
    rhs = mx.ones((1, 1, 1, 5))
    with self.assertRaises(ValueError):
        mx.vmap(add)(lhs, rhs)

    # Explicit in_axes=(0, 1): sizes 10 and 5 along the mapped axes.
    rhs = mx.ones((10, 5))
    with self.assertRaises(ValueError):
        mx.vmap(add, in_axes=(0, 1))(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user