NumberOfElements for shapeless compile and vmap fixes (#802)

Angelos Katharopoulos 2024-03-13 10:34:14 -07:00 committed by GitHub
parent 29d0c10ee5
commit 76c919b4ec
13 changed files with 289 additions and 72 deletions

View File

@@ -38,6 +38,7 @@ DEFAULT(Copy)
 DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)
 DEFAULT_MULTI(DivMod)
+DEFAULT(NumberOfElements)
 DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)

View File

@@ -51,6 +51,7 @@ DEFAULT(Cosh)
 DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
+DEFAULT(NumberOfElements)
 DEFAULT(Remainder)
 DEFAULT(Equal)
 DEFAULT(Erf)

View File

@@ -251,6 +251,62 @@ void Depends::eval(
   }
 }
 
+void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  double numel = 1;
+  for (auto ax : axes_) {
+    numel *= inputs[0].shape(ax);
+  }
+
+  if (inverted_) {
+    numel = 1.0 / numel;
+  }
+
+  switch (out.dtype()) {
+    case bool_:
+      *out.data<bool>() = static_cast<bool>(numel);
+      break;
+    case uint8:
+      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
+      break;
+    case uint16:
+      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
+      break;
+    case uint32:
+      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
+      break;
+    case uint64:
+      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
+      break;
+    case int8:
+      *out.data<int8_t>() = static_cast<int8_t>(numel);
+      break;
+    case int16:
+      *out.data<int16_t>() = static_cast<int16_t>(numel);
+      break;
+    case int32:
+      *out.data<int32_t>() = static_cast<int32_t>(numel);
+      break;
+    case int64:
+      *out.data<int64_t>() = static_cast<int64_t>(numel);
+      break;
+    case float16:
+      *out.data<float16_t>() = static_cast<float16_t>(numel);
+      break;
+    case float32:
+      *out.data<float>() = static_cast<float>(numel);
+      break;
+    case bfloat16:
+      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
+      break;
+    case complex64:
+      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
+      break;
+  }
+}
+
 void Erf::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];

View File

@@ -696,6 +696,10 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "min");
 }
 
+void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
+  eval(inputs, out);
+}
+
 void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "floor");
 }

View File

@@ -43,6 +43,7 @@ NO_GPU_MULTI(CustomVJP)
 NO_GPU_MULTI(Depends)
 NO_GPU(Divide)
 NO_GPU_MULTI(DivMod)
+NO_GPU(NumberOfElements)
 NO_GPU(Remainder)
 NO_GPU(Equal)
 NO_GPU(Erf)

View File

@@ -74,7 +74,7 @@ bool allows_shapeless(const Primitive& p) {
       is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
       typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
       typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
-      typeid(p) == typeid(Select);
+      typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
 }
 
 Compiled::Compiled(

View File

@@ -46,14 +46,6 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
   return {out_shape, sorted_axes};
 }
 
-int compute_number_of_elements(const array& a, const std::vector<int>& axes) {
-  int nelements = 1;
-  for (auto axis : axes) {
-    nelements *= a.shape(axis);
-  }
-  return nelements;
-}
-
 Dtype at_least_float(const Dtype& d) {
   return is_floating_point(d) ? d : promote_types(d, float32);
 }
@@ -1356,9 +1348,9 @@ array mean(
       throw std::invalid_argument(msg.str());
     }
   }
-  auto nelements = compute_number_of_elements(a, axes);
   auto dtype = at_least_float(a.dtype());
-  return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
+  auto normalizer = number_of_elements(a, axes, true, dtype, s);
+  return multiply(sum(a, axes, keepdims, s), normalizer, s);
 }
 
 array mean(
@@ -1391,9 +1383,12 @@ array var(
   auto v = subtract(a2, mu2, s);
   if (ddof != 0) {
-    auto nelements = compute_number_of_elements(a, axes);
-    auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
-    v = multiply(v, array(factor, dtype), s);
+    auto nelements = number_of_elements(a, axes, false, dtype, s);
+    auto factor = divide(
+        nelements,
+        maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),
+        s);
+    v = multiply(v, factor, s);
   }
   return v;
@@ -1770,7 +1765,7 @@ array logsumexp(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
-  auto maxval = stop_gradient(max(a, axes, true, s));
+  auto maxval = stop_gradient(max(a, axes, true, s), s);
   auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
   out = add(out, reshape(maxval, out.shape(), s), s);
   if (!keepdims) {
@@ -3600,4 +3595,29 @@ std::vector<array> atleast_3d(
   return out;
 }
 
+array number_of_elements(
+    const array& a,
+    std::vector<int> axes,
+    bool inverted,
+    Dtype dtype /* = int32 */,
+    StreamOrDevice s /* = {} */) {
+  for (auto& ax : axes) {
+    int normal_axis = (ax + a.ndim()) % a.ndim();
+    if (normal_axis >= a.ndim() || normal_axis < 0) {
+      std::ostringstream msg;
+      msg << "[number_of_elements] Can't get the shape for axis " << ax
+          << " from an array with " << a.ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    ax = normal_axis;
+  }
+
+  return stop_gradient(array(
+      std::vector<int>{},
+      dtype,
+      std::make_unique<NumberOfElements>(
+          to_stream(s), std::move(axes), inverted, dtype),
+      {a}));
+}
+
 } // namespace mlx::core
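
A worked note on the var() hunk above: the ddof correction factor N / max(N - ddof, 0) is now assembled from array ops (subtract, maximum, divide) instead of a host-side float, so the element count is no longer frozen into the graph at trace time. A rough scalar sketch of the same arithmetic, with illustrative numbers (N = 10, ddof = 1):

#include <algorithm>
#include <cassert>

// Equivalent scalar arithmetic for the correction applied to the biased
// variance v: factor = N / max(N - ddof, 0). Division by zero maps to inf,
// mirroring float array semantics.
double var_correction(double n, double ddof) {
  return n / std::max(n - ddof, 0.0);
}

int main() {
  assert(var_correction(10, 1) > 1.11 && var_correction(10, 1) < 1.12); // 10/9
  return 0;
}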

View File

@@ -1173,4 +1173,15 @@ std::vector<array> atleast_3d(
     const std::vector<array>& a,
     StreamOrDevice s = {});
 
+/**
+ * Extract the number of elements along some axes as a scalar array. Used to
+ * allow shape dependent shapeless compilation (pun intended).
+ */
+array number_of_elements(
+    const array& a,
+    std::vector<int> axes,
+    bool inverted,
+    Dtype dtype = int32,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core
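
A minimal usage sketch of the new number_of_elements API declared above, assuming the mlx/mlx.h umbrella header; the shapes, axes, and variable names are illustrative only:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array a = zeros({2, 3, 4});
  // Product of the sizes of axes 0 and 2, i.e. 2 * 4 = 8, as an int32 scalar.
  array n = number_of_elements(a, {0, 2}, /* inverted */ false, int32);
  // With inverted = true the primitive yields 1 / N instead, which is what
  // mean() now uses as a shape-dependent normalizer under shapeless compile.
  array inv_n = number_of_elements(a, {0, 2}, /* inverted */ true, float32);
  eval({n, inv_n});
  return 0;
}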

View File

@@ -23,6 +23,10 @@ std::tuple<array, array, int> vmap_binary_op(
   assert(inputs.size() == 2);
   assert(axes.size() == 2);
+  if (axes[0] == -1 && axes[1] == -1) {
+    return {inputs[0], inputs[1], -1};
+  }
+
   auto a = inputs[0];
   auto b = inputs[1];
   int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));
@@ -55,6 +59,10 @@ std::tuple<array, array, array, int> vmap_ternary_op(
   assert(inputs.size() == 3);
   assert(axes.size() == 3);
+  if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {
+    return {inputs[0], inputs[1], inputs[2], -1};
+  }
+
   auto a = inputs[0];
   auto b = inputs[1];
   auto c = inputs[2];
@@ -403,8 +411,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {
-      {argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 bool ArgPartition::is_equivalent(const Primitive& other) const {
@@ -420,7 +428,7 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  int reduce_ax = axis_ + (axis_ >= axes[0]);
+  int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);
   auto& in = inputs[0];
   std::vector<array> out;
   if (reduce_type_ == ArgReduce::ArgMin) {
@@ -437,7 +445,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 std::vector<std::vector<int>> ArgReduce::output_shapes(
@@ -563,13 +572,16 @@ std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto ax = axes[0];
-  auto in_shape = inputs[0].shape();
-  int diff = shape_.size() - inputs[0].ndim() + 1;
+  auto in = inputs[0];
+  if (ax >= 0) {
+    auto in_shape = in.shape();
+    int diff = shape_.size() - in.ndim() + 1;
     assert(diff >= 0);
     in_shape.insert(in_shape.begin(), diff, 1);
     ax += diff;
     shape_.insert(shape_.begin() + ax, in_shape[ax]);
-  auto in = reshape(inputs[0], in_shape, stream());
+    in = reshape(in, in_shape, stream());
+  }
   return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
@@ -653,15 +665,20 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   std::vector<array> t_inputs;
+  int out_ax = -1;
   // Find the first vmapped input
   int i = 0;
   for (; i < axes.size(); i++) {
     t_inputs.push_back(inputs[i]);
     if (axes[i] >= 0) {
+      out_ax = axes[i];
       break;
     }
   }
-  auto out_ax = axes[i++];
+
+  if (out_ax >= 0) {
+    // Advance to the next input
+    i++;
     // Move vmap axes to the same spot.
     for (; i < axes.size(); ++i) {
       if (out_ax != axes[i] && axes[i] >= 0) {
@@ -670,7 +687,8 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
         t_inputs.push_back(inputs[i]);
       }
     }
-  auto axis = axis_ + (axis_ >= out_ax);
+  }
+  auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
   return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
 }
@@ -1210,6 +1228,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
   int ax = axes[0];
   auto fft_axes = axes_;
   auto out_shape = in.shape();
+  if (ax >= 0) {
     for (auto& fft_ax : fft_axes) {
       if (fft_ax >= ax) {
         fft_ax++;
@@ -1219,6 +1238,7 @@
         out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
       }
     }
+  }
   return {
       {array(
           out_shape,
@@ -2064,7 +2084,8 @@ std::pair<std::vector<array>, std::vector<int>> Partition::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{partition(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 bool Partition::is_equivalent(const Primitive& other) const {
@@ -2185,7 +2206,9 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
   }
 
   auto shape = shape_;
+  if (kax >= 0) {
     shape.insert(shape.begin() + kax, key.shape()[kax]);
+  }
 
   auto get_dtype = [width = width_]() {
     switch (width) {
@@ -2217,6 +2240,7 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
   // Transpose the input so that the vmap dim is first.
   auto& in = inputs[0];
   auto ax = axes[0];
+  if (ax >= 0) {
     std::vector<int> reorder(in.ndim());
     std::iota(reorder.begin(), reorder.end(), 0);
     reorder.erase(reorder.begin() + ax);
@@ -2226,6 +2250,9 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
     shape_.insert(shape_.begin(), in.shape()[ax]);
     // Reshape the transposed input to the new shape.
     return {{reshape(out, shape_, stream())}, {0}};
+  } else {
+    return {{reshape(in, shape_, stream())}, {ax}};
+  }
 }
 
 std::vector<array> Reshape::vjp(
@@ -2349,11 +2376,13 @@ std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<int>& axes) {
   auto ax = axes[0];
   auto reduce_axes = axes_;
+  if (ax >= 0) {
     for (auto& rax : reduce_axes) {
       if (rax >= ax) {
         rax++;
       }
     }
+  }
   auto& in = inputs[0];
   std::vector<array> out;
   switch (reduce_type_) {
@@ -2424,16 +2453,13 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
   auto& in = inputs[0];
   auto out_dtype =
       (in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
   return {
       {array(
           in.shape(),
           out_dtype,
           std::make_unique<Scan>(
-              stream(),
-              reduce_type_,
-              axis_ + (axes[0] <= axis_),
-              reverse_,
-              inclusive_),
+              stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
           {in})},
       axes};
 }
@@ -2698,9 +2724,11 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
   auto strides = strides_;
   auto ax = axes[0];
   auto& input = inputs[0];
+  if (ax >= 0) {
     start.insert(start.begin() + ax, 0);
     stop.insert(stop.begin() + ax, input.shape(ax));
     strides.insert(strides.begin() + ax, 1);
+  }
   return {{slice(input, start, stop, strides, stream())}, {ax}};
 }
@@ -2796,7 +2824,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
   // We are vectorizing over an axis other than the last one so keep the
   // softmax axis unchanged
-  if (axes[0] < inputs[0].ndim() - 1) {
+  if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) {
     softmax_axes.push_back(-1);
   } else {
     softmax_axes.push_back(-2);
@@ -2837,7 +2865,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
-  return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{sort(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 std::vector<array> Sort::vjp(
@@ -2867,8 +2896,8 @@ bool Sort::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> Split::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {
-      {split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
 }
 
 std::vector<array> Split::vjp(
@@ -2971,7 +3000,7 @@ bool Sqrt::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {inputs, axes};
+  return {{stop_gradient(inputs[0], stream())}, axes};
 };
 
 std::vector<array> Subtract::vjp(
@@ -3093,12 +3122,14 @@ std::pair<std::vector<array>, std::vector<int>> Transpose::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto vdim = axes[0];
+  if (vdim >= 0) {
     for (auto& dim : axes_) {
      if (dim >= vdim) {
         dim++;
       }
     }
     axes_.insert(axes_.begin() + vdim, vdim);
+  }
   return {{transpose(inputs[0], axes_, stream())}, {vdim}};
 }
@@ -3107,4 +3138,35 @@ bool Transpose::is_equivalent(const Primitive& other) const {
   return axes_ == t_other.axes_;
 }
 
+std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+
+  std::vector<int> new_axes = axes_;
+  auto vdim = axes[0];
+  if (vdim >= 0) {
+    for (auto& dim : new_axes) {
+      if (dim >= vdim) {
+        dim++;
+      }
+    }
+  }
+
+  array out = array(
+      std::vector<int>{},
+      dtype_,
+      std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
+      inputs);
+
+  return {{out}, {-1}};
+}
+
+bool NumberOfElements::is_equivalent(const Primitive& other) const {
+  const NumberOfElements& n_other = static_cast<const NumberOfElements&>(other);
+  return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ &&
+      dtype_ == n_other.dtype_;
+}
+
 } // namespace mlx::core
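
The recurring pattern in the vmap fixes above is to guard axis arithmetic on whether an input is actually batched: a vmap axis of -1 means "not vmapped", so a primitive's own axis should only shift when a real batch axis landed to its left. A standalone sketch of that idiom, with a hypothetical helper name chosen for illustration:

#include <cassert>

// Shift a primitive's axis by one only if a batch axis was inserted at or
// before it; vmap_axis == -1 (input not vmapped) leaves the axis unchanged.
int shift_for_vmap(int axis, int vmap_axis) {
  int axis_left = vmap_axis >= 0 && vmap_axis <= axis; // 0 or 1
  return axis + axis_left;
}

int main() {
  assert(shift_for_vmap(1, 0) == 2);  // batch axis inserted to the left
  assert(shift_for_vmap(1, 2) == 1);  // batch axis inserted to the right
  assert(shift_for_vmap(1, -1) == 1); // input not vmapped: no shift
  return 0;
}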

View File

@@ -1229,6 +1229,37 @@ class NotEqual : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class NumberOfElements : public UnaryPrimitive {
+ public:
+  explicit NumberOfElements(
+      Stream stream,
+      std::vector<int> axes,
+      bool inverted,
+      Dtype dtype)
+      : UnaryPrimitive(stream),
+        axes_(std::move(axes)),
+        inverted_(inverted),
+        dtype_(dtype) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(NumberOfElements)
+  bool is_equivalent(const Primitive& other) const override;
+  std::vector<std::vector<int>> output_shapes(
+      const std::vector<array>& inputs) override {
+    return {{}};
+  }
+
+ private:
+  std::vector<int> axes_;
+  bool inverted_;
+  Dtype dtype_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Pad : public UnaryPrimitive {
  public:
   explicit Pad(

View File

@@ -653,6 +653,7 @@ std::vector<array> vmap_replace(
       v_axes.push_back(-1);
     }
   }
+
   auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
   // For each primitive's outputs add its id, the vout id and the vax
   auto outputs = a.outputs();

View File

@@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(expected[0], out[0]))
         self.assertTrue(mx.allclose(expected[1], out[1]))
 
+    def test_shapeless_mean(self):
+        def mean(x):
+            return mx.mean(x, keepdims=True)
+
+        cmean = mx.compile(mean, shapeless=True)
+
+        x = mx.ones(2)
+        out = cmean(x)
+        self.assertTrue(mx.allclose(out, mean(x)))
+
+        x = mx.ones(4)
+        out = cmean(x)
+        self.assertTrue(mx.allclose(out, mean(x)))
+
+        x = mx.ones(7)
+        out = cmean(x)
+        self.assertTrue(mx.allclose(out, mean(x)))
+
 
 if __name__ == "__main__":
     unittest.main()

View File

@@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
         expected = mx.array([2, 1])
         self.assertTrue(mx.array_equal(out, expected))
 
+    def test_vmap_mean(self):
+        a = mx.arange(8).reshape(2, 4)
+        out = mx.vmap(mx.mean)(a)
+        expected = mx.mean(a, axis=1)
+        self.assertTrue(mx.allclose(out, expected))
+
+        a = mx.arange(16).reshape(2, 2, 4)
+        out = mx.vmap(mx.vmap(mx.mean))(a)
+        expected = mx.mean(a, axis=2)
+        self.assertTrue(mx.allclose(out, expected))
+
     def test_mismatch_input_sizes(self):
         a = mx.ones((10, 1))
         b = mx.ones((1, 1, 1, 5))