NumberOfElements for shapeless compile and vmap fixes (#802)

This commit is contained in:
Angelos Katharopoulos 2024-03-13 10:34:14 -07:00 committed by GitHub
parent 29d0c10ee5
commit 76c919b4ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 289 additions and 72 deletions

View File

@ -38,6 +38,7 @@ DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)

View File

@ -51,6 +51,7 @@ DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)

View File

@ -251,6 +251,62 @@ void Depends::eval(
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@ -696,6 +696,10 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "floor");
}

View File

@ -43,6 +43,7 @@ NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(NumberOfElements)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)

View File

@ -74,7 +74,7 @@ bool allows_shapeless(const Primitive& p) {
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
typeid(p) == typeid(Select);
typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
}
Compiled::Compiled(

View File

@ -46,14 +46,6 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
return {out_shape, sorted_axes};
}
int compute_number_of_elements(const array& a, const std::vector<int>& axes) {
int nelements = 1;
for (auto axis : axes) {
nelements *= a.shape(axis);
}
return nelements;
}
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
@ -1356,9 +1348,9 @@ array mean(
throw std::invalid_argument(msg.str());
}
}
auto nelements = compute_number_of_elements(a, axes);
auto dtype = at_least_float(a.dtype());
return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
auto normalizer = number_of_elements(a, axes, true, dtype, s);
return multiply(sum(a, axes, keepdims, s), normalizer, s);
}
array mean(
@ -1391,9 +1383,12 @@ array var(
auto v = subtract(a2, mu2, s);
if (ddof != 0) {
auto nelements = compute_number_of_elements(a, axes);
auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
v = multiply(v, array(factor, dtype), s);
auto nelements = number_of_elements(a, axes, false, dtype, s);
auto factor = divide(
nelements,
maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),
s);
v = multiply(v, factor, s);
}
return v;
@ -1770,7 +1765,7 @@ array logsumexp(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
auto maxval = stop_gradient(max(a, axes, true, s));
auto maxval = stop_gradient(max(a, axes, true, s), s);
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
out = add(out, reshape(maxval, out.shape(), s), s);
if (!keepdims) {
@ -3600,4 +3595,29 @@ std::vector<array> atleast_3d(
return out;
}
array number_of_elements(
const array& a,
std::vector<int> axes,
bool inverted,
Dtype dtype /* = int32 */,
StreamOrDevice s /* = {} */) {
for (auto& ax : axes) {
int normal_axis = (ax + a.ndim()) % a.ndim();
if (normal_axis >= a.ndim() || normal_axis < 0) {
std::ostringstream msg;
msg << "[number_of_elements] Can't get the shape for axis " << ax
<< " from an array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
ax = normal_axis;
}
return stop_gradient(array(
std::vector<int>{},
dtype,
std::make_unique<NumberOfElements>(
to_stream(s), std::move(axes), inverted, dtype),
{a}));
}
} // namespace mlx::core

View File

@ -1173,4 +1173,15 @@ std::vector<array> atleast_3d(
const std::vector<array>& a,
StreamOrDevice s = {});
/**
* Extract the number of elements along some axes as a scalar array. Used to
* allow shape dependent shapeless compilation (pun intended).
*/
array number_of_elements(
const array& a,
std::vector<int> axes,
bool inverted,
Dtype dtype = int32,
StreamOrDevice s = {});
} // namespace mlx::core

View File

@ -23,6 +23,10 @@ std::tuple<array, array, int> vmap_binary_op(
assert(inputs.size() == 2);
assert(axes.size() == 2);
if (axes[0] == -1 && axes[1] == -1) {
return {inputs[0], inputs[1], -1};
}
auto a = inputs[0];
auto b = inputs[1];
int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));
@ -55,6 +59,10 @@ std::tuple<array, array, array, int> vmap_ternary_op(
assert(inputs.size() == 3);
assert(axes.size() == 3);
if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {
return {inputs[0], inputs[1], inputs[2], -1};
}
auto a = inputs[0];
auto b = inputs[1];
auto c = inputs[2];
@ -403,8 +411,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {
{argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
}
bool ArgPartition::is_equivalent(const Primitive& other) const {
@ -420,7 +428,7 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
int reduce_ax = axis_ + (axis_ >= axes[0]);
int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);
auto& in = inputs[0];
std::vector<array> out;
if (reduce_type_ == ArgReduce::ArgMin) {
@ -437,7 +445,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<std::vector<int>> ArgReduce::output_shapes(
@ -563,13 +572,16 @@ std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
auto ax = axes[0];
auto in_shape = inputs[0].shape();
int diff = shape_.size() - inputs[0].ndim() + 1;
assert(diff >= 0);
in_shape.insert(in_shape.begin(), diff, 1);
ax += diff;
shape_.insert(shape_.begin() + ax, in_shape[ax]);
auto in = reshape(inputs[0], in_shape, stream());
auto in = inputs[0];
if (ax >= 0) {
auto in_shape = in.shape();
int diff = shape_.size() - in.ndim() + 1;
assert(diff >= 0);
in_shape.insert(in_shape.begin(), diff, 1);
ax += diff;
shape_.insert(shape_.begin() + ax, in_shape[ax]);
in = reshape(in, in_shape, stream());
}
return {{broadcast_to(in, shape_, stream())}, {ax}};
}
@ -653,24 +665,30 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
std::vector<array> t_inputs;
int out_ax = -1;
// Find the first vmapped input
int i = 0;
for (; i < axes.size(); i++) {
t_inputs.push_back(inputs[i]);
if (axes[i] >= 0) {
out_ax = axes[i];
break;
}
}
auto out_ax = axes[i++];
// Move vmap axes to the same spot.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
if (out_ax >= 0) {
// Advance to the next input
i++;
// Move vmap axes to the same spot.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
}
}
}
auto axis = axis_ + (axis_ >= out_ax);
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
}
@ -1210,13 +1228,15 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
int ax = axes[0];
auto fft_axes = axes_;
auto out_shape = in.shape();
for (auto& fft_ax : fft_axes) {
if (fft_ax >= ax) {
fft_ax++;
}
if (real_) {
auto n = out_shape[fft_ax];
out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
if (ax >= 0) {
for (auto& fft_ax : fft_axes) {
if (fft_ax >= ax) {
fft_ax++;
}
if (real_) {
auto n = out_shape[fft_ax];
out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
}
}
}
return {
@ -2064,7 +2084,8 @@ std::pair<std::vector<array>, std::vector<int>> Partition::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{partition(inputs[0], axis_ + axis_left, stream())}, axes};
}
bool Partition::is_equivalent(const Primitive& other) const {
@ -2185,7 +2206,9 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
}
auto shape = shape_;
shape.insert(shape.begin() + kax, key.shape()[kax]);
if (kax >= 0) {
shape.insert(shape.begin() + kax, key.shape()[kax]);
}
auto get_dtype = [width = width_]() {
switch (width) {
@ -2217,15 +2240,19 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
// Transpose the input so that the vmap dim is first.
auto& in = inputs[0];
auto ax = axes[0];
std::vector<int> reorder(in.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + ax);
reorder.insert(reorder.begin(), ax);
// Insert the vmap dim into the shape at the beginning.
auto out = transpose(in, reorder, stream());
shape_.insert(shape_.begin(), in.shape()[ax]);
// Reshape the transposed input to the new shape.
return {{reshape(out, shape_, stream())}, {0}};
if (ax >= 0) {
std::vector<int> reorder(in.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + ax);
reorder.insert(reorder.begin(), ax);
// Insert the vmap dim into the shape at the beginning.
auto out = transpose(in, reorder, stream());
shape_.insert(shape_.begin(), in.shape()[ax]);
// Reshape the transposed input to the new shape.
return {{reshape(out, shape_, stream())}, {0}};
} else {
return {{reshape(in, shape_, stream())}, {ax}};
}
}
std::vector<array> Reshape::vjp(
@ -2349,9 +2376,11 @@ std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
const std::vector<int>& axes) {
auto ax = axes[0];
auto reduce_axes = axes_;
for (auto& rax : reduce_axes) {
if (rax >= ax) {
rax++;
if (ax >= 0) {
for (auto& rax : reduce_axes) {
if (rax >= ax) {
rax++;
}
}
}
auto& in = inputs[0];
@ -2424,16 +2453,13 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
auto& in = inputs[0];
auto out_dtype =
(in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {
{array(
in.shape(),
out_dtype,
std::make_unique<Scan>(
stream(),
reduce_type_,
axis_ + (axes[0] <= axis_),
reverse_,
inclusive_),
stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
{in})},
axes};
}
@ -2698,9 +2724,11 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
auto strides = strides_;
auto ax = axes[0];
auto& input = inputs[0];
start.insert(start.begin() + ax, 0);
stop.insert(stop.begin() + ax, input.shape(ax));
strides.insert(strides.begin() + ax, 1);
if (ax >= 0) {
start.insert(start.begin() + ax, 0);
stop.insert(stop.begin() + ax, input.shape(ax));
strides.insert(strides.begin() + ax, 1);
}
return {{slice(input, start, stop, strides, stream())}, {ax}};
}
@ -2796,7 +2824,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
// We are vectorizing over an axis other than the last one so keep the
// softmax axis unchanged
if (axes[0] < inputs[0].ndim() - 1) {
if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) {
softmax_axes.push_back(-1);
} else {
softmax_axes.push_back(-2);
@ -2837,7 +2865,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{sort(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<array> Sort::vjp(
@ -2867,8 +2896,8 @@ bool Sort::is_equivalent(const Primitive& other) const {
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {
{split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
}
std::vector<array> Split::vjp(
@ -2971,7 +3000,7 @@ bool Sqrt::is_equivalent(const Primitive& other) const {
std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {inputs, axes};
return {{stop_gradient(inputs[0], stream())}, axes};
};
std::vector<array> Subtract::vjp(
@ -3093,12 +3122,14 @@ std::pair<std::vector<array>, std::vector<int>> Transpose::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
auto vdim = axes[0];
for (auto& dim : axes_) {
if (dim >= vdim) {
dim++;
if (vdim >= 0) {
for (auto& dim : axes_) {
if (dim >= vdim) {
dim++;
}
}
axes_.insert(axes_.begin() + vdim, vdim);
}
axes_.insert(axes_.begin() + vdim, vdim);
return {{transpose(inputs[0], axes_, stream())}, {vdim}};
}
@ -3107,4 +3138,35 @@ bool Transpose::is_equivalent(const Primitive& other) const {
return axes_ == t_other.axes_;
}
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
std::vector<int> new_axes = axes_;
auto vdim = axes[0];
if (vdim >= 0) {
for (auto& dim : new_axes) {
if (dim >= vdim) {
dim++;
}
}
}
array out = array(
std::vector<int>{},
dtype_,
std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
inputs);
return {{out}, {-1}};
}
bool NumberOfElements::is_equivalent(const Primitive& other) const {
const NumberOfElements& n_other = static_cast<const NumberOfElements&>(other);
return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ &&
dtype_ == n_other.dtype_;
}
} // namespace mlx::core

View File

@ -1229,6 +1229,37 @@ class NotEqual : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class NumberOfElements : public UnaryPrimitive {
public:
explicit NumberOfElements(
Stream stream,
std::vector<int> axes,
bool inverted,
Dtype dtype)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
inverted_(inverted),
dtype_(dtype) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(NumberOfElements)
bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
return {{}};
}
private:
std::vector<int> axes_;
bool inverted_;
Dtype dtype_;
void eval(const std::vector<array>& inputs, array& out);
};
class Pad : public UnaryPrimitive {
public:
explicit Pad(

View File

@ -653,6 +653,7 @@ std::vector<array> vmap_replace(
v_axes.push_back(-1);
}
}
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
// For each primitive's outputs add its id, the vout id and the vax
auto outputs = a.outputs();

View File

@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(expected[0], out[0]))
self.assertTrue(mx.allclose(expected[1], out[1]))
def test_shapeless_mean(self):
def mean(x):
return mx.mean(x, keepdims=True)
cmean = mx.compile(mean, shapeless=True)
x = mx.ones(2)
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
x = mx.ones(4)
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
x = mx.ones(7)
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
if __name__ == "__main__":
unittest.main()

View File

@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = mx.array([2, 1])
self.assertTrue(mx.array_equal(out, expected))
def test_vmap_mean(self):
a = mx.arange(8).reshape(2, 4)
out = mx.vmap(mx.mean)(a)
expected = mx.mean(a, axis=1)
self.assertTrue(mx.allclose(out, expected))
a = mx.arange(16).reshape(2, 2, 4)
out = mx.vmap(mx.vmap(mx.mean))(a)
expected = mx.mean(a, axis=2)
self.assertTrue(mx.allclose(out, expected))
def test_mismatch_input_sizes(self):
a = mx.ones((10, 1))
b = mx.ones((1, 1, 1, 5))