mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
NumberOfElements for shapeless compile and vmap fixes (#802)
This commit is contained in:
parent
29d0c10ee5
commit
76c919b4ec
@ -38,6 +38,7 @@ DEFAULT(Copy)
|
|||||||
DEFAULT_MULTI(CustomVJP)
|
DEFAULT_MULTI(CustomVJP)
|
||||||
DEFAULT_MULTI(Depends)
|
DEFAULT_MULTI(Depends)
|
||||||
DEFAULT_MULTI(DivMod)
|
DEFAULT_MULTI(DivMod)
|
||||||
|
DEFAULT(NumberOfElements)
|
||||||
DEFAULT(Equal)
|
DEFAULT(Equal)
|
||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
DEFAULT(ErfInv)
|
DEFAULT(ErfInv)
|
||||||
|
@ -51,6 +51,7 @@ DEFAULT(Cosh)
|
|||||||
DEFAULT_MULTI(CustomVJP)
|
DEFAULT_MULTI(CustomVJP)
|
||||||
DEFAULT_MULTI(Depends)
|
DEFAULT_MULTI(Depends)
|
||||||
DEFAULT(Divide)
|
DEFAULT(Divide)
|
||||||
|
DEFAULT(NumberOfElements)
|
||||||
DEFAULT(Remainder)
|
DEFAULT(Remainder)
|
||||||
DEFAULT(Equal)
|
DEFAULT(Equal)
|
||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
|
@ -251,6 +251,62 @@ void Depends::eval(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
double numel = 1;
|
||||||
|
for (auto ax : axes_) {
|
||||||
|
numel *= inputs[0].shape(ax);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inverted_) {
|
||||||
|
numel = 1.0 / numel;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
*out.data<bool>() = static_cast<bool>(numel);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
*out.data<int8_t>() = static_cast<int8_t>(numel);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
*out.data<int16_t>() = static_cast<int16_t>(numel);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
*out.data<int32_t>() = static_cast<int32_t>(numel);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
*out.data<int64_t>() = static_cast<int64_t>(numel);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
*out.data<float16_t>() = static_cast<float16_t>(numel);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
*out.data<float>() = static_cast<float>(numel);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
@ -696,6 +696,10 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
binary_op(inputs, out, "min");
|
binary_op(inputs, out, "min");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
unary_op(inputs, out, "floor");
|
unary_op(inputs, out, "floor");
|
||||||
}
|
}
|
||||||
|
@ -43,6 +43,7 @@ NO_GPU_MULTI(CustomVJP)
|
|||||||
NO_GPU_MULTI(Depends)
|
NO_GPU_MULTI(Depends)
|
||||||
NO_GPU(Divide)
|
NO_GPU(Divide)
|
||||||
NO_GPU_MULTI(DivMod)
|
NO_GPU_MULTI(DivMod)
|
||||||
|
NO_GPU(NumberOfElements)
|
||||||
NO_GPU(Remainder)
|
NO_GPU(Remainder)
|
||||||
NO_GPU(Equal)
|
NO_GPU(Equal)
|
||||||
NO_GPU(Erf)
|
NO_GPU(Erf)
|
||||||
|
@ -74,7 +74,7 @@ bool allows_shapeless(const Primitive& p) {
|
|||||||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
||||||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
||||||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
|
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
|
||||||
typeid(p) == typeid(Select);
|
typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
|
||||||
}
|
}
|
||||||
|
|
||||||
Compiled::Compiled(
|
Compiled::Compiled(
|
||||||
|
48
mlx/ops.cpp
48
mlx/ops.cpp
@ -46,14 +46,6 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
|||||||
return {out_shape, sorted_axes};
|
return {out_shape, sorted_axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
int compute_number_of_elements(const array& a, const std::vector<int>& axes) {
|
|
||||||
int nelements = 1;
|
|
||||||
for (auto axis : axes) {
|
|
||||||
nelements *= a.shape(axis);
|
|
||||||
}
|
|
||||||
return nelements;
|
|
||||||
}
|
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||||
}
|
}
|
||||||
@ -1356,9 +1348,9 @@ array mean(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto nelements = compute_number_of_elements(a, axes);
|
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
|
auto normalizer = number_of_elements(a, axes, true, dtype, s);
|
||||||
|
return multiply(sum(a, axes, keepdims, s), normalizer, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mean(
|
array mean(
|
||||||
@ -1391,9 +1383,12 @@ array var(
|
|||||||
auto v = subtract(a2, mu2, s);
|
auto v = subtract(a2, mu2, s);
|
||||||
|
|
||||||
if (ddof != 0) {
|
if (ddof != 0) {
|
||||||
auto nelements = compute_number_of_elements(a, axes);
|
auto nelements = number_of_elements(a, axes, false, dtype, s);
|
||||||
auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
|
auto factor = divide(
|
||||||
v = multiply(v, array(factor, dtype), s);
|
nelements,
|
||||||
|
maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),
|
||||||
|
s);
|
||||||
|
v = multiply(v, factor, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
return v;
|
return v;
|
||||||
@ -1770,7 +1765,7 @@ array logsumexp(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
bool keepdims /* = false */,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s /* = {}*/) {
|
StreamOrDevice s /* = {}*/) {
|
||||||
auto maxval = stop_gradient(max(a, axes, true, s));
|
auto maxval = stop_gradient(max(a, axes, true, s), s);
|
||||||
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
|
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
|
||||||
out = add(out, reshape(maxval, out.shape(), s), s);
|
out = add(out, reshape(maxval, out.shape(), s), s);
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
@ -3600,4 +3595,29 @@ std::vector<array> atleast_3d(
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array number_of_elements(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> axes,
|
||||||
|
bool inverted,
|
||||||
|
Dtype dtype /* = int32 */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
for (auto& ax : axes) {
|
||||||
|
int normal_axis = (ax + a.ndim()) % a.ndim();
|
||||||
|
if (normal_axis >= a.ndim() || normal_axis < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[number_of_elements] Can't get the shape for axis " << ax
|
||||||
|
<< " from an array with " << a.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
ax = normal_axis;
|
||||||
|
}
|
||||||
|
|
||||||
|
return stop_gradient(array(
|
||||||
|
std::vector<int>{},
|
||||||
|
dtype,
|
||||||
|
std::make_unique<NumberOfElements>(
|
||||||
|
to_stream(s), std::move(axes), inverted, dtype),
|
||||||
|
{a}));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
11
mlx/ops.h
11
mlx/ops.h
@ -1173,4 +1173,15 @@ std::vector<array> atleast_3d(
|
|||||||
const std::vector<array>& a,
|
const std::vector<array>& a,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract the number of elements along some axes as a scalar array. Used to
|
||||||
|
* allow shape dependent shapeless compilation (pun intended).
|
||||||
|
*/
|
||||||
|
array number_of_elements(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> axes,
|
||||||
|
bool inverted,
|
||||||
|
Dtype dtype = int32,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -23,6 +23,10 @@ std::tuple<array, array, int> vmap_binary_op(
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
assert(axes.size() == 2);
|
assert(axes.size() == 2);
|
||||||
|
|
||||||
|
if (axes[0] == -1 && axes[1] == -1) {
|
||||||
|
return {inputs[0], inputs[1], -1};
|
||||||
|
}
|
||||||
|
|
||||||
auto a = inputs[0];
|
auto a = inputs[0];
|
||||||
auto b = inputs[1];
|
auto b = inputs[1];
|
||||||
int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));
|
int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));
|
||||||
@ -55,6 +59,10 @@ std::tuple<array, array, array, int> vmap_ternary_op(
|
|||||||
assert(inputs.size() == 3);
|
assert(inputs.size() == 3);
|
||||||
assert(axes.size() == 3);
|
assert(axes.size() == 3);
|
||||||
|
|
||||||
|
if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {
|
||||||
|
return {inputs[0], inputs[1], inputs[2], -1};
|
||||||
|
}
|
||||||
|
|
||||||
auto a = inputs[0];
|
auto a = inputs[0];
|
||||||
auto b = inputs[1];
|
auto b = inputs[1];
|
||||||
auto c = inputs[2];
|
auto c = inputs[2];
|
||||||
@ -403,8 +411,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(axes.size() == 1);
|
assert(axes.size() == 1);
|
||||||
|
|
||||||
return {
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
{argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ArgPartition::is_equivalent(const Primitive& other) const {
|
bool ArgPartition::is_equivalent(const Primitive& other) const {
|
||||||
@ -420,7 +428,7 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
|
|||||||
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
|
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
int reduce_ax = axis_ + (axis_ >= axes[0]);
|
int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
std::vector<array> out;
|
std::vector<array> out;
|
||||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||||
@ -437,7 +445,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(axes.size() == 1);
|
assert(axes.size() == 1);
|
||||||
|
|
||||||
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
|
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int>> ArgReduce::output_shapes(
|
std::vector<std::vector<int>> ArgReduce::output_shapes(
|
||||||
@ -563,13 +572,16 @@ std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(axes.size() == 1);
|
assert(axes.size() == 1);
|
||||||
auto ax = axes[0];
|
auto ax = axes[0];
|
||||||
auto in_shape = inputs[0].shape();
|
auto in = inputs[0];
|
||||||
int diff = shape_.size() - inputs[0].ndim() + 1;
|
if (ax >= 0) {
|
||||||
assert(diff >= 0);
|
auto in_shape = in.shape();
|
||||||
in_shape.insert(in_shape.begin(), diff, 1);
|
int diff = shape_.size() - in.ndim() + 1;
|
||||||
ax += diff;
|
assert(diff >= 0);
|
||||||
shape_.insert(shape_.begin() + ax, in_shape[ax]);
|
in_shape.insert(in_shape.begin(), diff, 1);
|
||||||
auto in = reshape(inputs[0], in_shape, stream());
|
ax += diff;
|
||||||
|
shape_.insert(shape_.begin() + ax, in_shape[ax]);
|
||||||
|
in = reshape(in, in_shape, stream());
|
||||||
|
}
|
||||||
return {{broadcast_to(in, shape_, stream())}, {ax}};
|
return {{broadcast_to(in, shape_, stream())}, {ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -653,24 +665,30 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
std::vector<array> t_inputs;
|
std::vector<array> t_inputs;
|
||||||
|
int out_ax = -1;
|
||||||
// Find the first vmapped input
|
// Find the first vmapped input
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (; i < axes.size(); i++) {
|
for (; i < axes.size(); i++) {
|
||||||
t_inputs.push_back(inputs[i]);
|
t_inputs.push_back(inputs[i]);
|
||||||
if (axes[i] >= 0) {
|
if (axes[i] >= 0) {
|
||||||
|
out_ax = axes[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto out_ax = axes[i++];
|
if (out_ax >= 0) {
|
||||||
// Move vmap axes to the same spot.
|
// Advance to the next input
|
||||||
for (; i < axes.size(); ++i) {
|
i++;
|
||||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
|
||||||
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
// Move vmap axes to the same spot.
|
||||||
} else {
|
for (; i < axes.size(); ++i) {
|
||||||
t_inputs.push_back(inputs[i]);
|
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||||
|
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
||||||
|
} else {
|
||||||
|
t_inputs.push_back(inputs[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto axis = axis_ + (axis_ >= out_ax);
|
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
|
||||||
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
|
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1210,13 +1228,15 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
|
|||||||
int ax = axes[0];
|
int ax = axes[0];
|
||||||
auto fft_axes = axes_;
|
auto fft_axes = axes_;
|
||||||
auto out_shape = in.shape();
|
auto out_shape = in.shape();
|
||||||
for (auto& fft_ax : fft_axes) {
|
if (ax >= 0) {
|
||||||
if (fft_ax >= ax) {
|
for (auto& fft_ax : fft_axes) {
|
||||||
fft_ax++;
|
if (fft_ax >= ax) {
|
||||||
}
|
fft_ax++;
|
||||||
if (real_) {
|
}
|
||||||
auto n = out_shape[fft_ax];
|
if (real_) {
|
||||||
out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
|
auto n = out_shape[fft_ax];
|
||||||
|
out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
@ -2064,7 +2084,8 @@ std::pair<std::vector<array>, std::vector<int>> Partition::vmap(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(axes.size() == 1);
|
assert(axes.size() == 1);
|
||||||
|
|
||||||
return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
|
return {{partition(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Partition::is_equivalent(const Primitive& other) const {
|
bool Partition::is_equivalent(const Primitive& other) const {
|
||||||
@ -2185,7 +2206,9 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto shape = shape_;
|
auto shape = shape_;
|
||||||
shape.insert(shape.begin() + kax, key.shape()[kax]);
|
if (kax >= 0) {
|
||||||
|
shape.insert(shape.begin() + kax, key.shape()[kax]);
|
||||||
|
}
|
||||||
|
|
||||||
auto get_dtype = [width = width_]() {
|
auto get_dtype = [width = width_]() {
|
||||||
switch (width) {
|
switch (width) {
|
||||||
@ -2217,15 +2240,19 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
|
|||||||
// Transpose the input so that the vmap dim is first.
|
// Transpose the input so that the vmap dim is first.
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
auto ax = axes[0];
|
auto ax = axes[0];
|
||||||
std::vector<int> reorder(in.ndim());
|
if (ax >= 0) {
|
||||||
std::iota(reorder.begin(), reorder.end(), 0);
|
std::vector<int> reorder(in.ndim());
|
||||||
reorder.erase(reorder.begin() + ax);
|
std::iota(reorder.begin(), reorder.end(), 0);
|
||||||
reorder.insert(reorder.begin(), ax);
|
reorder.erase(reorder.begin() + ax);
|
||||||
// Insert the vmap dim into the shape at the beginning.
|
reorder.insert(reorder.begin(), ax);
|
||||||
auto out = transpose(in, reorder, stream());
|
// Insert the vmap dim into the shape at the beginning.
|
||||||
shape_.insert(shape_.begin(), in.shape()[ax]);
|
auto out = transpose(in, reorder, stream());
|
||||||
// Reshape the transposed input to the new shape.
|
shape_.insert(shape_.begin(), in.shape()[ax]);
|
||||||
return {{reshape(out, shape_, stream())}, {0}};
|
// Reshape the transposed input to the new shape.
|
||||||
|
return {{reshape(out, shape_, stream())}, {0}};
|
||||||
|
} else {
|
||||||
|
return {{reshape(in, shape_, stream())}, {ax}};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Reshape::vjp(
|
std::vector<array> Reshape::vjp(
|
||||||
@ -2349,9 +2376,11 @@ std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
|
|||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto ax = axes[0];
|
auto ax = axes[0];
|
||||||
auto reduce_axes = axes_;
|
auto reduce_axes = axes_;
|
||||||
for (auto& rax : reduce_axes) {
|
if (ax >= 0) {
|
||||||
if (rax >= ax) {
|
for (auto& rax : reduce_axes) {
|
||||||
rax++;
|
if (rax >= ax) {
|
||||||
|
rax++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
@ -2424,16 +2453,13 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
|
|||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
auto out_dtype =
|
auto out_dtype =
|
||||||
(in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();
|
(in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();
|
||||||
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
return {
|
return {
|
||||||
{array(
|
{array(
|
||||||
in.shape(),
|
in.shape(),
|
||||||
out_dtype,
|
out_dtype,
|
||||||
std::make_unique<Scan>(
|
std::make_unique<Scan>(
|
||||||
stream(),
|
stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
|
||||||
reduce_type_,
|
|
||||||
axis_ + (axes[0] <= axis_),
|
|
||||||
reverse_,
|
|
||||||
inclusive_),
|
|
||||||
{in})},
|
{in})},
|
||||||
axes};
|
axes};
|
||||||
}
|
}
|
||||||
@ -2698,9 +2724,11 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
|
|||||||
auto strides = strides_;
|
auto strides = strides_;
|
||||||
auto ax = axes[0];
|
auto ax = axes[0];
|
||||||
auto& input = inputs[0];
|
auto& input = inputs[0];
|
||||||
start.insert(start.begin() + ax, 0);
|
if (ax >= 0) {
|
||||||
stop.insert(stop.begin() + ax, input.shape(ax));
|
start.insert(start.begin() + ax, 0);
|
||||||
strides.insert(strides.begin() + ax, 1);
|
stop.insert(stop.begin() + ax, input.shape(ax));
|
||||||
|
strides.insert(strides.begin() + ax, 1);
|
||||||
|
}
|
||||||
return {{slice(input, start, stop, strides, stream())}, {ax}};
|
return {{slice(input, start, stop, strides, stream())}, {ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2796,7 +2824,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
|
|||||||
|
|
||||||
// We are vectorizing over an axis other than the last one so keep the
|
// We are vectorizing over an axis other than the last one so keep the
|
||||||
// softmax axis unchanged
|
// softmax axis unchanged
|
||||||
if (axes[0] < inputs[0].ndim() - 1) {
|
if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) {
|
||||||
softmax_axes.push_back(-1);
|
softmax_axes.push_back(-1);
|
||||||
} else {
|
} else {
|
||||||
softmax_axes.push_back(-2);
|
softmax_axes.push_back(-2);
|
||||||
@ -2837,7 +2865,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(axes.size() == 1);
|
assert(axes.size() == 1);
|
||||||
|
|
||||||
return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
|
return {{sort(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Sort::vjp(
|
std::vector<array> Sort::vjp(
|
||||||
@ -2867,8 +2896,8 @@ bool Sort::is_equivalent(const Primitive& other) const {
|
|||||||
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
return {
|
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||||
{split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
|
return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Split::vjp(
|
std::vector<array> Split::vjp(
|
||||||
@ -2971,7 +3000,7 @@ bool Sqrt::is_equivalent(const Primitive& other) const {
|
|||||||
std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
|
std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
return {inputs, axes};
|
return {{stop_gradient(inputs[0], stream())}, axes};
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<array> Subtract::vjp(
|
std::vector<array> Subtract::vjp(
|
||||||
@ -3093,12 +3122,14 @@ std::pair<std::vector<array>, std::vector<int>> Transpose::vmap(
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(axes.size() == 1);
|
assert(axes.size() == 1);
|
||||||
auto vdim = axes[0];
|
auto vdim = axes[0];
|
||||||
for (auto& dim : axes_) {
|
if (vdim >= 0) {
|
||||||
if (dim >= vdim) {
|
for (auto& dim : axes_) {
|
||||||
dim++;
|
if (dim >= vdim) {
|
||||||
|
dim++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
axes_.insert(axes_.begin() + vdim, vdim);
|
||||||
}
|
}
|
||||||
axes_.insert(axes_.begin() + vdim, vdim);
|
|
||||||
return {{transpose(inputs[0], axes_, stream())}, {vdim}};
|
return {{transpose(inputs[0], axes_, stream())}, {vdim}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3107,4 +3138,35 @@ bool Transpose::is_equivalent(const Primitive& other) const {
|
|||||||
return axes_ == t_other.axes_;
|
return axes_ == t_other.axes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(axes.size() == 1);
|
||||||
|
|
||||||
|
std::vector<int> new_axes = axes_;
|
||||||
|
auto vdim = axes[0];
|
||||||
|
if (vdim >= 0) {
|
||||||
|
for (auto& dim : new_axes) {
|
||||||
|
if (dim >= vdim) {
|
||||||
|
dim++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array out = array(
|
||||||
|
std::vector<int>{},
|
||||||
|
dtype_,
|
||||||
|
std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
|
||||||
|
inputs);
|
||||||
|
|
||||||
|
return {{out}, {-1}};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NumberOfElements::is_equivalent(const Primitive& other) const {
|
||||||
|
const NumberOfElements& n_other = static_cast<const NumberOfElements&>(other);
|
||||||
|
return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ &&
|
||||||
|
dtype_ == n_other.dtype_;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1229,6 +1229,37 @@ class NotEqual : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class NumberOfElements : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit NumberOfElements(
|
||||||
|
Stream stream,
|
||||||
|
std::vector<int> axes,
|
||||||
|
bool inverted,
|
||||||
|
Dtype dtype)
|
||||||
|
: UnaryPrimitive(stream),
|
||||||
|
axes_(std::move(axes)),
|
||||||
|
inverted_(inverted),
|
||||||
|
dtype_(dtype) {}
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_PRINT(NumberOfElements)
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<std::vector<int>> output_shapes(
|
||||||
|
const std::vector<array>& inputs) override {
|
||||||
|
return {{}};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int> axes_;
|
||||||
|
bool inverted_;
|
||||||
|
Dtype dtype_;
|
||||||
|
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class Pad : public UnaryPrimitive {
|
class Pad : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Pad(
|
explicit Pad(
|
||||||
|
@ -653,6 +653,7 @@ std::vector<array> vmap_replace(
|
|||||||
v_axes.push_back(-1);
|
v_axes.push_back(-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
|
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
|
||||||
// For each primitive's outputs add its id, the vout id and the vax
|
// For each primitive's outputs add its id, the vout id and the vax
|
||||||
auto outputs = a.outputs();
|
auto outputs = a.outputs();
|
||||||
|
@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(expected[0], out[0]))
|
self.assertTrue(mx.allclose(expected[0], out[0]))
|
||||||
self.assertTrue(mx.allclose(expected[1], out[1]))
|
self.assertTrue(mx.allclose(expected[1], out[1]))
|
||||||
|
|
||||||
|
def test_shapeless_mean(self):
|
||||||
|
def mean(x):
|
||||||
|
return mx.mean(x, keepdims=True)
|
||||||
|
|
||||||
|
cmean = mx.compile(mean, shapeless=True)
|
||||||
|
|
||||||
|
x = mx.ones(2)
|
||||||
|
out = cmean(x)
|
||||||
|
self.assertTrue(mx.allclose(out, mean(x)))
|
||||||
|
|
||||||
|
x = mx.ones(4)
|
||||||
|
out = cmean(x)
|
||||||
|
self.assertTrue(mx.allclose(out, mean(x)))
|
||||||
|
|
||||||
|
x = mx.ones(7)
|
||||||
|
out = cmean(x)
|
||||||
|
self.assertTrue(mx.allclose(out, mean(x)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
expected = mx.array([2, 1])
|
expected = mx.array([2, 1])
|
||||||
self.assertTrue(mx.array_equal(out, expected))
|
self.assertTrue(mx.array_equal(out, expected))
|
||||||
|
|
||||||
|
def test_vmap_mean(self):
|
||||||
|
a = mx.arange(8).reshape(2, 4)
|
||||||
|
out = mx.vmap(mx.mean)(a)
|
||||||
|
expected = mx.mean(a, axis=1)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
a = mx.arange(16).reshape(2, 2, 4)
|
||||||
|
out = mx.vmap(mx.vmap(mx.mean))(a)
|
||||||
|
expected = mx.mean(a, axis=2)
|
||||||
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
def test_mismatch_input_sizes(self):
|
def test_mismatch_input_sizes(self):
|
||||||
a = mx.ones((10, 1))
|
a = mx.ones((10, 1))
|
||||||
b = mx.ones((1, 1, 1, 5))
|
b = mx.ones((1, 1, 1, 5))
|
||||||
|
Loading…
Reference in New Issue
Block a user