Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 01:41:17 +08:00
NumberOfElements for shapeless compile and vmap fixes (#802)
This commit is contained in:
parent: 29d0c10ee5
commit: 76c919b4ec
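In brief: mean and var previously folded the element count of the reduced axes into the graph as a trace-time constant, so a function compiled with shapeless=True kept a stale normalizer when the input shape changed. They now build the normalizer through a new NumberOfElements primitive that is evaluated with the rest of the graph. The commit also fixes a batch of vmap rules that mishandled inputs carrying no vmap axis (the -1 sentinel).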
@@ -38,6 +38,7 @@ DEFAULT(Copy)
 DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)
 DEFAULT_MULTI(DivMod)
+DEFAULT(NumberOfElements)
 DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)
@@ -51,6 +51,7 @@ DEFAULT(Cosh)
 DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
+DEFAULT(NumberOfElements)
 DEFAULT(Remainder)
 DEFAULT(Equal)
 DEFAULT(Erf)
@@ -251,6 +251,62 @@ void Depends::eval(
   }
 }
 
+void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  double numel = 1;
+  for (auto ax : axes_) {
+    numel *= inputs[0].shape(ax);
+  }
+
+  if (inverted_) {
+    numel = 1.0 / numel;
+  }
+
+  switch (out.dtype()) {
+    case bool_:
+      *out.data<bool>() = static_cast<bool>(numel);
+      break;
+    case uint8:
+      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
+      break;
+    case uint16:
+      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
+      break;
+    case uint32:
+      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
+      break;
+    case uint64:
+      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
+      break;
+    case int8:
+      *out.data<int8_t>() = static_cast<int8_t>(numel);
+      break;
+    case int16:
+      *out.data<int16_t>() = static_cast<int16_t>(numel);
+      break;
+    case int32:
+      *out.data<int32_t>() = static_cast<int32_t>(numel);
+      break;
+    case int64:
+      *out.data<int64_t>() = static_cast<int64_t>(numel);
+      break;
+    case float16:
+      *out.data<float16_t>() = static_cast<float16_t>(numel);
+      break;
+    case float32:
+      *out.data<float>() = static_cast<float>(numel);
+      break;
+    case bfloat16:
+      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
+      break;
+    case complex64:
+      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
+      break;
+  }
+}
+
 void Erf::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
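For reference, the semantics this new CPU kernel implements, sketched in Python (the helper name is illustrative, not part of mlx): the product of the input's sizes along axes_, optionally inverted, written into a scalar of the requested dtype.

def number_of_elements_ref(shape, axes, inverted):
    # Product of the sizes along the given axes, accumulated in double
    # precision like the kernel above, then optionally inverted.
    numel = 1.0
    for ax in axes:
        numel *= shape[ax]
    return 1.0 / numel if inverted else numel

assert number_of_elements_ref((2, 3, 4), [1, 2], inverted=False) == 12.0
assert number_of_elements_ref((2, 3, 4), [1, 2], inverted=True) == 1 / 12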
@@ -696,6 +696,10 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "min");
 }
 
+void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
+  eval(inputs, out);
+}
+
 void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "floor");
 }
@@ -43,6 +43,7 @@ NO_GPU_MULTI(CustomVJP)
 NO_GPU_MULTI(Depends)
 NO_GPU(Divide)
 NO_GPU_MULTI(DivMod)
+NO_GPU(NumberOfElements)
 NO_GPU(Remainder)
 NO_GPU(Equal)
 NO_GPU(Erf)
@@ -74,7 +74,7 @@ bool allows_shapeless(const Primitive& p) {
       is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
       typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
       typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
-      typeid(p) == typeid(Select);
+      typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
 }
 
 Compiled::Compiled(
mlx/ops.cpp
@@ -46,14 +46,6 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
   return {out_shape, sorted_axes};
 }
 
-int compute_number_of_elements(const array& a, const std::vector<int>& axes) {
-  int nelements = 1;
-  for (auto axis : axes) {
-    nelements *= a.shape(axis);
-  }
-  return nelements;
-}
-
 Dtype at_least_float(const Dtype& d) {
   return is_floating_point(d) ? d : promote_types(d, float32);
 }
@@ -1356,9 +1348,9 @@ array mean(
       throw std::invalid_argument(msg.str());
     }
   }
-  auto nelements = compute_number_of_elements(a, axes);
   auto dtype = at_least_float(a.dtype());
-  return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
+  auto normalizer = number_of_elements(a, axes, true, dtype, s);
+  return multiply(sum(a, axes, keepdims, s), normalizer, s);
 }
 
 array mean(
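The effect of this change: the 1/n normalizer becomes a node in the graph (an inverted NumberOfElements on the input) instead of a constant captured from the shape at trace time, so a shapeless-compiled mean recomputes it for each new input shape. A sketch of the user-visible behavior (mirrors test_shapeless_mean added at the end of this diff):

import mlx.core as mx

cmean = mx.compile(lambda x: mx.mean(x, keepdims=True), shapeless=True)
for n in (2, 4, 7):
    out = cmean(mx.ones(n))
    assert mx.allclose(out, mx.array(1.0))  # normalizer 1/n tracks the shape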
@@ -1391,9 +1383,12 @@ array var(
   auto v = subtract(a2, mu2, s);
 
   if (ddof != 0) {
-    auto nelements = compute_number_of_elements(a, axes);
-    auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
-    v = multiply(v, array(factor, dtype), s);
+    auto nelements = number_of_elements(a, axes, false, dtype, s);
+    auto factor = divide(
+        nelements,
+        maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),
+        s);
+    v = multiply(v, factor, s);
   }
 
   return v;
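Written out, the ddof correction now assembled from graph ops is factor = n / max(n - ddof, 0), with n produced by NumberOfElements so it tracks the runtime shape. A quick illustration of the same arithmetic through the Python API (values are illustrative):

import mlx.core as mx

n, ddof = mx.array(8.0), mx.array(1.0)
factor = mx.divide(n, mx.maximum(mx.subtract(n, ddof), mx.array(0.0)))
print(factor.item())  # 8/7, rescales the biased variance to the unbiased estimate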
@@ -1770,7 +1765,7 @@ array logsumexp(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
-  auto maxval = stop_gradient(max(a, axes, true, s));
+  auto maxval = stop_gradient(max(a, axes, true, s), s);
   auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
   out = add(out, reshape(maxval, out.shape(), s), s);
   if (!keepdims) {
@@ -3600,4 +3595,29 @@ std::vector<array> atleast_3d(
   return out;
 }
 
+array number_of_elements(
+    const array& a,
+    std::vector<int> axes,
+    bool inverted,
+    Dtype dtype /* = int32 */,
+    StreamOrDevice s /* = {} */) {
+  for (auto& ax : axes) {
+    int normal_axis = (ax + a.ndim()) % a.ndim();
+    if (normal_axis >= a.ndim() || normal_axis < 0) {
+      std::ostringstream msg;
+      msg << "[number_of_elements] Can't get the shape for axis " << ax
+          << " from an array with " << a.ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    ax = normal_axis;
+  }
+
+  return stop_gradient(array(
+      std::vector<int>{},
+      dtype,
+      std::make_unique<NumberOfElements>(
+          to_stream(s), std::move(axes), inverted, dtype),
+      {a}));
+}
+
 } // namespace mlx::core
mlx/ops.h
@@ -1173,4 +1173,15 @@ std::vector<array> atleast_3d(
     const std::vector<array>& a,
     StreamOrDevice s = {});
 
+/**
+ * Extract the number of elements along some axes as a scalar array. Used to
+ * allow shape dependent shapeless compilation (pun intended).
+ */
+array number_of_elements(
+    const array& a,
+    std::vector<int> axes,
+    bool inverted,
+    Dtype dtype = int32,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core
@@ -23,6 +23,10 @@ std::tuple<array, array, int> vmap_binary_op(
   assert(inputs.size() == 2);
   assert(axes.size() == 2);
 
+  if (axes[0] == -1 && axes[1] == -1) {
+    return {inputs[0], inputs[1], -1};
+  }
+
   auto a = inputs[0];
   auto b = inputs[1];
   int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));
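This early return covers calls where no input is batched: inside vmap, an axis of -1 marks an argument with no batch dimension, for example an array captured by the mapped function rather than passed through it. A small Python sketch of how such axes arise (assuming the usual capture behavior):

import mlx.core as mx

c = mx.arange(3)                      # captured constant: unbatched, vmap axis -1
f = mx.vmap(lambda x: mx.add(x, c))   # only x is mapped, over axis 0
print(f(mx.ones((2, 3))).shape)       # (2, 3)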
@@ -55,6 +59,10 @@ std::tuple<array, array, array, int> vmap_ternary_op(
   assert(inputs.size() == 3);
   assert(axes.size() == 3);
 
+  if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {
+    return {inputs[0], inputs[1], inputs[2], -1};
+  }
+
   auto a = inputs[0];
   auto b = inputs[1];
   auto c = inputs[2];
@@ -403,8 +411,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
 
-  return {
-      {argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 bool ArgPartition::is_equivalent(const Primitive& other) const {
@@ -420,7 +428,7 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  int reduce_ax = axis_ + (axis_ >= axes[0]);
+  int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);
   auto& in = inputs[0];
   std::vector<array> out;
   if (reduce_type_ == ArgReduce::ArgMin) {
@@ -437,7 +445,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
 
-  return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 std::vector<std::vector<int>> ArgReduce::output_shapes(
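The same axis_left pattern recurs in Partition, Scan, Sort, and Split below: the op's axis is shifted only when the input actually carries a vmap axis (axes[0] >= 0) at or before it. Previously the unbatched sentinel -1 satisfied axes[0] <= axis_ and shifted the axis spuriously. In plain Python:

# Illustrative values: axis_ is the op's axis, vmap_ax = -1 means unbatched.
axis_, vmap_ax = 1, -1
old = axis_ + (vmap_ax <= axis_)                    # 2: off by one
new = axis_ + (vmap_ax >= 0 and vmap_ax <= axis_)   # 1: axis unchanged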
@@ -563,13 +572,16 @@ std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto ax = axes[0];
-  auto in_shape = inputs[0].shape();
-  int diff = shape_.size() - inputs[0].ndim() + 1;
-  assert(diff >= 0);
-  in_shape.insert(in_shape.begin(), diff, 1);
-  ax += diff;
-  shape_.insert(shape_.begin() + ax, in_shape[ax]);
-  auto in = reshape(inputs[0], in_shape, stream());
+  auto in = inputs[0];
+  if (ax >= 0) {
+    auto in_shape = in.shape();
+    int diff = shape_.size() - in.ndim() + 1;
+    assert(diff >= 0);
+    in_shape.insert(in_shape.begin(), diff, 1);
+    ax += diff;
+    shape_.insert(shape_.begin() + ax, in_shape[ax]);
+    in = reshape(in, in_shape, stream());
+  }
   return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
 
@@ -653,24 +665,30 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   std::vector<array> t_inputs;
+  int out_ax = -1;
   // Find the first vmapped input
   int i = 0;
   for (; i < axes.size(); i++) {
     t_inputs.push_back(inputs[i]);
     if (axes[i] >= 0) {
+      out_ax = axes[i];
       break;
     }
   }
-  auto out_ax = axes[i++];
-  // Move vmap axes to the same spot.
-  for (; i < axes.size(); ++i) {
-    if (out_ax != axes[i] && axes[i] >= 0) {
-      t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
-    } else {
-      t_inputs.push_back(inputs[i]);
+  if (out_ax >= 0) {
+    // Advance to the next input
+    i++;
+
+    // Move vmap axes to the same spot.
+    for (; i < axes.size(); ++i) {
+      if (out_ax != axes[i] && axes[i] >= 0) {
+        t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
+      } else {
+        t_inputs.push_back(inputs[i]);
+      }
     }
   }
-  auto axis = axis_ + (axis_ >= out_ax);
+  auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
   return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
 }
 
@@ -1210,13 +1228,15 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
   int ax = axes[0];
   auto fft_axes = axes_;
   auto out_shape = in.shape();
-  for (auto& fft_ax : fft_axes) {
-    if (fft_ax >= ax) {
-      fft_ax++;
-    }
-    if (real_) {
-      auto n = out_shape[fft_ax];
-      out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
+  if (ax >= 0) {
+    for (auto& fft_ax : fft_axes) {
+      if (fft_ax >= ax) {
+        fft_ax++;
+      }
+      if (real_) {
+        auto n = out_shape[fft_ax];
+        out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
+      }
     }
   }
   return {
@@ -2064,7 +2084,8 @@ std::pair<std::vector<array>, std::vector<int>> Partition::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
 
-  return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{partition(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 bool Partition::is_equivalent(const Primitive& other) const {
@@ -2185,7 +2206,9 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
   }
 
   auto shape = shape_;
-  shape.insert(shape.begin() + kax, key.shape()[kax]);
+  if (kax >= 0) {
+    shape.insert(shape.begin() + kax, key.shape()[kax]);
+  }
 
   auto get_dtype = [width = width_]() {
     switch (width) {
@@ -2217,15 +2240,19 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
   // Transpose the input so that the vmap dim is first.
   auto& in = inputs[0];
   auto ax = axes[0];
-  std::vector<int> reorder(in.ndim());
-  std::iota(reorder.begin(), reorder.end(), 0);
-  reorder.erase(reorder.begin() + ax);
-  reorder.insert(reorder.begin(), ax);
-  // Insert the vmap dim into the shape at the beginning.
-  auto out = transpose(in, reorder, stream());
-  shape_.insert(shape_.begin(), in.shape()[ax]);
-  // Reshape the transposed input to the new shape.
-  return {{reshape(out, shape_, stream())}, {0}};
+  if (ax >= 0) {
+    std::vector<int> reorder(in.ndim());
+    std::iota(reorder.begin(), reorder.end(), 0);
+    reorder.erase(reorder.begin() + ax);
+    reorder.insert(reorder.begin(), ax);
+    // Insert the vmap dim into the shape at the beginning.
+    auto out = transpose(in, reorder, stream());
+    shape_.insert(shape_.begin(), in.shape()[ax]);
+    // Reshape the transposed input to the new shape.
+    return {{reshape(out, shape_, stream())}, {0}};
+  } else {
+    return {{reshape(in, shape_, stream())}, {ax}};
+  }
 }
 
 std::vector<array> Reshape::vjp(
@@ -2349,9 +2376,11 @@ std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<int>& axes) {
   auto ax = axes[0];
   auto reduce_axes = axes_;
-  for (auto& rax : reduce_axes) {
-    if (rax >= ax) {
-      rax++;
+  if (ax >= 0) {
+    for (auto& rax : reduce_axes) {
+      if (rax >= ax) {
+        rax++;
+      }
     }
   }
   auto& in = inputs[0];
@@ -2424,16 +2453,13 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
   auto& in = inputs[0];
   auto out_dtype =
       (in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
   return {
       {array(
           in.shape(),
           out_dtype,
           std::make_unique<Scan>(
-              stream(),
-              reduce_type_,
-              axis_ + (axes[0] <= axis_),
-              reverse_,
-              inclusive_),
+              stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
           {in})},
       axes};
 }
@@ -2698,9 +2724,11 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
   auto strides = strides_;
   auto ax = axes[0];
   auto& input = inputs[0];
-  start.insert(start.begin() + ax, 0);
-  stop.insert(stop.begin() + ax, input.shape(ax));
-  strides.insert(strides.begin() + ax, 1);
+  if (ax >= 0) {
+    start.insert(start.begin() + ax, 0);
+    stop.insert(stop.begin() + ax, input.shape(ax));
+    strides.insert(strides.begin() + ax, 1);
+  }
   return {{slice(input, start, stop, strides, stream())}, {ax}};
 }
 
@@ -2796,7 +2824,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
 
   // We are vectorizing over an axis other than the last one so keep the
   // softmax axis unchanged
-  if (axes[0] < inputs[0].ndim() - 1) {
+  if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) {
     softmax_axes.push_back(-1);
   } else {
     softmax_axes.push_back(-2);
@@ -2837,7 +2865,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
 
-  return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{sort(inputs[0], axis_ + axis_left, stream())}, axes};
 }
 
 std::vector<array> Sort::vjp(
@@ -2867,8 +2896,8 @@ bool Sort::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> Split::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {
-      {split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
+  int axis_left = axes[0] >= 0 && axes[0] <= axis_;
+  return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
 }
 
 std::vector<array> Split::vjp(
@@ -2971,7 +3000,7 @@ bool Sqrt::is_equivalent(const Primitive& other) const {
 std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {inputs, axes};
+  return {{stop_gradient(inputs[0], stream())}, axes};
 };
 
 std::vector<array> Subtract::vjp(
@@ -3093,12 +3122,14 @@ std::pair<std::vector<array>, std::vector<int>> Transpose::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto vdim = axes[0];
-  for (auto& dim : axes_) {
-    if (dim >= vdim) {
-      dim++;
+  if (vdim >= 0) {
+    for (auto& dim : axes_) {
+      if (dim >= vdim) {
+        dim++;
+      }
     }
+    axes_.insert(axes_.begin() + vdim, vdim);
   }
-  axes_.insert(axes_.begin() + vdim, vdim);
   return {{transpose(inputs[0], axes_, stream())}, {vdim}};
 }
 
@@ -3107,4 +3138,35 @@ bool Transpose::is_equivalent(const Primitive& other) const {
   return axes_ == t_other.axes_;
 }
 
+std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 1);
+  assert(axes.size() == 1);
+
+  std::vector<int> new_axes = axes_;
+  auto vdim = axes[0];
+  if (vdim >= 0) {
+    for (auto& dim : new_axes) {
+      if (dim >= vdim) {
+        dim++;
+      }
+    }
+  }
+
+  array out = array(
+      std::vector<int>{},
+      dtype_,
+      std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
+      inputs);
+
+  return {{out}, {-1}};
+}
+
+bool NumberOfElements::is_equivalent(const Primitive& other) const {
+  const NumberOfElements& n_other = static_cast<const NumberOfElements&>(other);
+  return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ &&
+      dtype_ == n_other.dtype_;
+}
+
 } // namespace mlx::core
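NumberOfElements composes with vmap by shifting its counted axes past the new batch dimension and returning an unbatched scalar (output axis -1). Together with the mean change above, this is what lets vmapped reductions normalize correctly, as in the test added at the end of this diff:

import mlx.core as mx

a = mx.arange(8).reshape(2, 4)
print(mx.vmap(mx.mean)(a))  # equals mx.mean(a, axis=1); each row divides by 4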
@@ -1229,6 +1229,37 @@ class NotEqual : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class NumberOfElements : public UnaryPrimitive {
+ public:
+  explicit NumberOfElements(
+      Stream stream,
+      std::vector<int> axes,
+      bool inverted,
+      Dtype dtype)
+      : UnaryPrimitive(stream),
+        axes_(std::move(axes)),
+        inverted_(inverted),
+        dtype_(dtype) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(NumberOfElements)
+  bool is_equivalent(const Primitive& other) const override;
+  std::vector<std::vector<int>> output_shapes(
+      const std::vector<array>& inputs) override {
+    return {{}};
+  }
+
+ private:
+  std::vector<int> axes_;
+  bool inverted_;
+  Dtype dtype_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Pad : public UnaryPrimitive {
  public:
  explicit Pad(
@@ -653,6 +653,7 @@ std::vector<array> vmap_replace(
         v_axes.push_back(-1);
       }
     }
+
     auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
     // For each primitive's outputs add its id, the vout id and the vax
     auto outputs = a.outputs();
@@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(expected[0], out[0]))
         self.assertTrue(mx.allclose(expected[1], out[1]))
 
+    def test_shapeless_mean(self):
+        def mean(x):
+            return mx.mean(x, keepdims=True)
+
+        cmean = mx.compile(mean, shapeless=True)
+
+        x = mx.ones(2)
+        out = cmean(x)
+        self.assertTrue(mx.allclose(out, mean(x)))
+
+        x = mx.ones(4)
+        out = cmean(x)
+        self.assertTrue(mx.allclose(out, mean(x)))
+
+        x = mx.ones(7)
+        out = cmean(x)
+        self.assertTrue(mx.allclose(out, mean(x)))
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
         expected = mx.array([2, 1])
         self.assertTrue(mx.array_equal(out, expected))
 
+    def test_vmap_mean(self):
+        a = mx.arange(8).reshape(2, 4)
+        out = mx.vmap(mx.mean)(a)
+        expected = mx.mean(a, axis=1)
+        self.assertTrue(mx.allclose(out, expected))
+
+        a = mx.arange(16).reshape(2, 2, 4)
+        out = mx.vmap(mx.vmap(mx.mean))(a)
+        expected = mx.mean(a, axis=2)
+        self.assertTrue(mx.allclose(out, expected))
+
     def test_mismatch_input_sizes(self):
         a = mx.ones((10, 1))
         b = mx.ones((1, 1, 1, 5))