mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
ExpandDims
primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
parent
310ad8d9db
commit
f76a49e555
@ -43,6 +43,7 @@ DEFAULT(NumberOfElements)
|
|||||||
DEFAULT(Equal)
|
DEFAULT(Equal)
|
||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
DEFAULT(ErfInv)
|
DEFAULT(ErfInv)
|
||||||
|
DEFAULT(ExpandDims)
|
||||||
DEFAULT(FFT)
|
DEFAULT(FFT)
|
||||||
DEFAULT(Floor)
|
DEFAULT(Floor)
|
||||||
DEFAULT(Gather)
|
DEFAULT(Gather)
|
||||||
@ -76,6 +77,7 @@ DEFAULT(Slice)
|
|||||||
DEFAULT(SliceUpdate)
|
DEFAULT(SliceUpdate)
|
||||||
DEFAULT_MULTI(Split)
|
DEFAULT_MULTI(Split)
|
||||||
DEFAULT(Sort)
|
DEFAULT(Sort)
|
||||||
|
DEFAULT(Squeeze)
|
||||||
DEFAULT(StopGradient)
|
DEFAULT(StopGradient)
|
||||||
DEFAULT_MULTI(SVD)
|
DEFAULT_MULTI(SVD)
|
||||||
DEFAULT(Transpose)
|
DEFAULT(Transpose)
|
||||||
|
@ -85,6 +85,16 @@ void Depends::eval(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
auto strides = in.strides();
|
||||||
|
for (auto ax : axes_) {
|
||||||
|
strides.insert(strides.begin() + ax, 1);
|
||||||
|
}
|
||||||
|
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
@ -248,6 +258,20 @@ void Split::eval(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
Strides strides;
|
||||||
|
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||||
|
if (j < axes_.size() && i == axes_[j]) {
|
||||||
|
j++;
|
||||||
|
} else {
|
||||||
|
strides.push_back(in.strides(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
move_or_copy(inputs[0], out);
|
move_or_copy(inputs[0], out);
|
||||||
|
@ -57,6 +57,7 @@ DEFAULT(Equal)
|
|||||||
DEFAULT(Erf)
|
DEFAULT(Erf)
|
||||||
DEFAULT(ErfInv)
|
DEFAULT(ErfInv)
|
||||||
DEFAULT(Exp)
|
DEFAULT(Exp)
|
||||||
|
DEFAULT(ExpandDims)
|
||||||
DEFAULT(Expm1)
|
DEFAULT(Expm1)
|
||||||
DEFAULT(FFT)
|
DEFAULT(FFT)
|
||||||
DEFAULT(Floor)
|
DEFAULT(Floor)
|
||||||
@ -101,6 +102,7 @@ DEFAULT(Softmax)
|
|||||||
DEFAULT(Sort)
|
DEFAULT(Sort)
|
||||||
DEFAULT_MULTI(Split)
|
DEFAULT_MULTI(Split)
|
||||||
DEFAULT(Square)
|
DEFAULT(Square)
|
||||||
|
DEFAULT(Squeeze)
|
||||||
DEFAULT(Sqrt)
|
DEFAULT(Sqrt)
|
||||||
DEFAULT(StopGradient)
|
DEFAULT(StopGradient)
|
||||||
DEFAULT(Subtract)
|
DEFAULT(Subtract)
|
||||||
|
@ -211,6 +211,10 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
copy_gpu(in, out, ctype);
|
copy_gpu(in, out, ctype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
auto read_task = [out = out,
|
auto read_task = [out = out,
|
||||||
@ -381,6 +385,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/* const Stream& s = */ stream());
|
/* const Stream& s = */ stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
@ -55,6 +55,7 @@ NO_CPU(Equal)
|
|||||||
NO_CPU(Erf)
|
NO_CPU(Erf)
|
||||||
NO_CPU(ErfInv)
|
NO_CPU(ErfInv)
|
||||||
NO_CPU(Exp)
|
NO_CPU(Exp)
|
||||||
|
NO_CPU(ExpandDims)
|
||||||
NO_CPU(Expm1)
|
NO_CPU(Expm1)
|
||||||
NO_CPU(FFT)
|
NO_CPU(FFT)
|
||||||
NO_CPU(Floor)
|
NO_CPU(Floor)
|
||||||
@ -104,6 +105,7 @@ NO_CPU(Softmax)
|
|||||||
NO_CPU(Sort)
|
NO_CPU(Sort)
|
||||||
NO_CPU_MULTI(Split)
|
NO_CPU_MULTI(Split)
|
||||||
NO_CPU(Square)
|
NO_CPU(Square)
|
||||||
|
NO_CPU(Squeeze)
|
||||||
NO_CPU(Sqrt)
|
NO_CPU(Sqrt)
|
||||||
NO_CPU(StopGradient)
|
NO_CPU(StopGradient)
|
||||||
NO_CPU(Subtract)
|
NO_CPU(Subtract)
|
||||||
|
@ -55,6 +55,7 @@ NO_GPU(Equal)
|
|||||||
NO_GPU(Erf)
|
NO_GPU(Erf)
|
||||||
NO_GPU(ErfInv)
|
NO_GPU(ErfInv)
|
||||||
NO_GPU(Exp)
|
NO_GPU(Exp)
|
||||||
|
NO_GPU(ExpandDims)
|
||||||
NO_GPU(Expm1)
|
NO_GPU(Expm1)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
NO_GPU(Floor)
|
NO_GPU(Floor)
|
||||||
@ -104,6 +105,7 @@ NO_GPU(Softmax)
|
|||||||
NO_GPU(Sort)
|
NO_GPU(Sort)
|
||||||
NO_GPU_MULTI(Split)
|
NO_GPU_MULTI(Split)
|
||||||
NO_GPU(Square)
|
NO_GPU(Square)
|
||||||
|
NO_GPU(Squeeze)
|
||||||
NO_GPU(Sqrt)
|
NO_GPU(Sqrt)
|
||||||
NO_GPU(StopGradient)
|
NO_GPU(StopGradient)
|
||||||
NO_GPU(Subtract)
|
NO_GPU(Subtract)
|
||||||
|
@ -81,6 +81,7 @@ bool allows_shapeless(const Primitive& p) {
|
|||||||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||||
|
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
|
||||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||||
typeid(p) == typeid(fast::LayerNorm) ||
|
typeid(p) == typeid(fast::LayerNorm) ||
|
||||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||||
|
295
mlx/ops.cpp
295
mlx/ops.cpp
@ -20,7 +20,7 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
|
std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
const Shape& shape) {
|
const Shape& shape) {
|
||||||
bool is_noop = true;
|
bool is_noop = true;
|
||||||
@ -40,18 +40,16 @@ std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
|
|||||||
throw std::invalid_argument("Duplicate axes detected in reduction.");
|
throw std::invalid_argument("Duplicate axes detected in reduction.");
|
||||||
}
|
}
|
||||||
Shape out_shape;
|
Shape out_shape;
|
||||||
Shape squeezed_shape;
|
|
||||||
for (int i = 0; i < ndim; ++i) {
|
for (int i = 0; i < ndim; ++i) {
|
||||||
if (axes_set.count(i) == 0) {
|
if (axes_set.count(i) == 0) {
|
||||||
out_shape.push_back(shape[i]);
|
out_shape.push_back(shape[i]);
|
||||||
squeezed_shape.push_back(shape[i]);
|
|
||||||
} else {
|
} else {
|
||||||
out_shape.push_back(1);
|
out_shape.push_back(1);
|
||||||
}
|
}
|
||||||
is_noop &= (out_shape.back() == shape[i]);
|
is_noop &= (out_shape.back() == shape[i]);
|
||||||
}
|
}
|
||||||
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
|
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
|
||||||
return {out_shape, sorted_axes, squeezed_shape, is_noop};
|
return {out_shape, sorted_axes, is_noop};
|
||||||
}
|
}
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
@ -460,54 +458,51 @@ array hadamard_transform(
|
|||||||
{astype(a, dtype, s)});
|
{astype(a, dtype, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array squeeze_impl(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
for (auto& ax : axes) {
|
||||||
|
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
|
||||||
|
if (new_ax < 0 || new_ax >= a.ndim()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (a.shape(new_ax) != 1) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
|
||||||
|
<< a.shape(ax) << " which is not equal to 1.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
ax = new_ax;
|
||||||
|
}
|
||||||
|
auto shape = Squeeze::output_shape(a, axes);
|
||||||
|
return array(
|
||||||
|
std::move(shape),
|
||||||
|
a.dtype(),
|
||||||
|
std::make_shared<Squeeze>(to_stream(s), std::move(axes)),
|
||||||
|
{a});
|
||||||
|
}
|
||||||
|
|
||||||
array squeeze(
|
array squeeze(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
std::set<int> unique_axes;
|
std::set<int> unique_axes;
|
||||||
for (auto ax : axes) {
|
for (auto ax : axes) {
|
||||||
ax = ax < 0 ? ax + a.ndim() : ax;
|
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
|
||||||
if (ax < 0 || ax >= a.ndim()) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
|
|
||||||
<< " dimensions.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
if (a.shape(ax) != 1) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
|
|
||||||
<< a.shape(ax) << " which is not equal to 1.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
unique_axes.insert(ax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (unique_axes.size() != axes.size()) {
|
if (unique_axes.size() != axes.size()) {
|
||||||
throw std::invalid_argument("[squeeze] Received duplicate axes.");
|
throw std::invalid_argument("[squeeze] Received duplicate axes.");
|
||||||
}
|
}
|
||||||
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
|
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
|
||||||
Shape shape;
|
return squeeze_impl(a, std::move(sorted_axes), s);
|
||||||
for (int i = 0, j = 0; i < a.ndim(); ++i) {
|
|
||||||
if (j < sorted_axes.size() && i == sorted_axes[j]) {
|
|
||||||
j++;
|
|
||||||
} else {
|
|
||||||
shape.push_back(a.shape(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return reshape(a, std::move(shape), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||||
int ax = axis < 0 ? axis + a.ndim() : axis;
|
return squeeze_impl(a, {axis}, s);
|
||||||
if (ax < 0 || ax >= a.ndim()) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
|
|
||||||
<< " dimensions.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
auto shape = a.shape();
|
|
||||||
shape.erase(shape.begin() + ax);
|
|
||||||
return reshape(a, std::move(shape), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
@ -517,21 +512,34 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
axes.push_back(i);
|
axes.push_back(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return squeeze(a, axes, s);
|
return squeeze_impl(a, std::move(axes), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array expand_dims_impl(
|
||||||
|
const array& a,
|
||||||
|
std::vector<int> axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
auto out_ndim = a.ndim() + axes.size();
|
||||||
|
for (auto& ax : axes) {
|
||||||
|
auto new_ax = ax < 0 ? ax + out_ndim : ax;
|
||||||
|
if (new_ax < 0 || new_ax >= out_ndim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
|
||||||
|
<< a.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
ax = new_ax;
|
||||||
|
}
|
||||||
|
auto shape = ExpandDims::output_shape(a, axes);
|
||||||
|
return array(
|
||||||
|
std::move(shape),
|
||||||
|
a.dtype(),
|
||||||
|
std::make_shared<ExpandDims>(to_stream(s), std::move(axes)),
|
||||||
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||||
int out_dim = a.ndim() + 1;
|
return expand_dims_impl(a, {axis}, s);
|
||||||
int ax = axis < 0 ? axis + out_dim : axis;
|
|
||||||
if (ax < 0 || ax >= out_dim) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
|
|
||||||
<< a.ndim() << " dimensions.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
auto shape = a.shape();
|
|
||||||
shape.insert(shape.begin() + ax, 1);
|
|
||||||
return reshape(a, std::move(shape), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array expand_dims(
|
array expand_dims(
|
||||||
@ -544,31 +552,17 @@ array expand_dims(
|
|||||||
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
|
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int out_ndim = axes.size() + a.ndim();
|
|
||||||
std::vector<int> canonical_axes = axes;
|
|
||||||
for (auto& ax : canonical_axes) {
|
|
||||||
ax = ax < 0 ? ax + out_ndim : ax;
|
|
||||||
if (ax < 0 || ax >= out_ndim) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
|
|
||||||
<< a.ndim() << " dimensions.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for repeats again
|
// Check for repeats again
|
||||||
std::set<int> unique_axes(canonical_axes.begin(), canonical_axes.end());
|
auto out_ndim = a.ndim() + axes.size();
|
||||||
|
std::set<int> unique_axes;
|
||||||
|
for (auto ax : axes) {
|
||||||
|
unique_axes.insert(ax < 0 ? ax + out_ndim : ax);
|
||||||
|
}
|
||||||
if (unique_axes.size() != axes.size()) {
|
if (unique_axes.size() != axes.size()) {
|
||||||
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
|
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
|
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
|
||||||
auto out_shape = a.shape();
|
return expand_dims_impl(a, std::move(sorted_axes), s);
|
||||||
for (int i = 0; i < sorted_axes.size(); ++i) {
|
|
||||||
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
|
|
||||||
}
|
|
||||||
return reshape(a, std::move(out_shape), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slice helper
|
// Slice helper
|
||||||
@ -1519,7 +1513,7 @@ array all(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
bool keepdims /* = false */,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s /* = {}*/) {
|
StreamOrDevice s /* = {}*/) {
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape(axes, a.shape());
|
compute_reduce_shape(axes, a.shape());
|
||||||
auto out = (is_noop)
|
auto out = (is_noop)
|
||||||
? astype(a, bool_, s)
|
? astype(a, bool_, s)
|
||||||
@ -1529,7 +1523,7 @@ array all(
|
|||||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1553,7 +1547,7 @@ array any(
|
|||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
bool keepdims /* = false */,
|
bool keepdims /* = false */,
|
||||||
StreamOrDevice s /* = {}*/) {
|
StreamOrDevice s /* = {}*/) {
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape(axes, a.shape());
|
compute_reduce_shape(axes, a.shape());
|
||||||
auto out = (is_noop)
|
auto out = (is_noop)
|
||||||
? astype(a, bool_, s)
|
? astype(a, bool_, s)
|
||||||
@ -1563,7 +1557,7 @@ array any(
|
|||||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1590,7 +1584,7 @@ array sum(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape(axes, a.shape());
|
compute_reduce_shape(axes, a.shape());
|
||||||
Dtype out_type = a.dtype();
|
Dtype out_type = a.dtype();
|
||||||
if (issubdtype(a.dtype(), signedinteger)) {
|
if (issubdtype(a.dtype(), signedinteger)) {
|
||||||
@ -1608,7 +1602,7 @@ array sum(
|
|||||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1742,7 +1736,7 @@ array prod(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape(axes, a.shape());
|
compute_reduce_shape(axes, a.shape());
|
||||||
Dtype out_type = a.dtype();
|
Dtype out_type = a.dtype();
|
||||||
if (issubdtype(a.dtype(), signedinteger)) {
|
if (issubdtype(a.dtype(), signedinteger)) {
|
||||||
@ -1760,7 +1754,7 @@ array prod(
|
|||||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1787,7 +1781,7 @@ array max(
|
|||||||
if (a.size() == 0) {
|
if (a.size() == 0) {
|
||||||
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
|
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape(axes, a.shape());
|
compute_reduce_shape(axes, a.shape());
|
||||||
auto out = (is_noop)
|
auto out = (is_noop)
|
||||||
? a
|
? a
|
||||||
@ -1797,7 +1791,7 @@ array max(
|
|||||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1827,7 +1821,7 @@ array min(
|
|||||||
if (axes.empty()) {
|
if (axes.empty()) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape(axes, a.shape());
|
compute_reduce_shape(axes, a.shape());
|
||||||
auto out = (is_noop)
|
auto out = (is_noop)
|
||||||
? a
|
? a
|
||||||
@ -1837,7 +1831,7 @@ array min(
|
|||||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1870,7 +1864,7 @@ array argmin(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[argmin] Cannot argmin reduce zero size array.");
|
"[argmin] Cannot argmin reduce zero size array.");
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape({axis}, a.shape());
|
compute_reduce_shape({axis}, a.shape());
|
||||||
auto out = (is_noop)
|
auto out = (is_noop)
|
||||||
? zeros(out_shape, uint32, s)
|
? zeros(out_shape, uint32, s)
|
||||||
@ -1881,7 +1875,7 @@ array argmin(
|
|||||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes[0], s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -1906,7 +1900,7 @@ array argmax(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[argmax] Cannot argmax reduce zero size array.");
|
"[argmax] Cannot argmax reduce zero size array.");
|
||||||
}
|
}
|
||||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
auto [out_shape, sorted_axes, is_noop] =
|
||||||
compute_reduce_shape({axis}, a.shape());
|
compute_reduce_shape({axis}, a.shape());
|
||||||
auto out = (is_noop)
|
auto out = (is_noop)
|
||||||
? zeros(out_shape, uint32, s)
|
? zeros(out_shape, uint32, s)
|
||||||
@ -1917,7 +1911,7 @@ array argmax(
|
|||||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = reshape(out, std::move(squeezed_shape), s);
|
out = squeeze(out, sorted_axes[0], s);
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
@ -2544,11 +2538,11 @@ array matmul(
|
|||||||
}
|
}
|
||||||
if (a.ndim() == 1) {
|
if (a.ndim() == 1) {
|
||||||
// Insert a singleton dim in the beginning
|
// Insert a singleton dim in the beginning
|
||||||
a = reshape(a, {1, -1}, s);
|
a = expand_dims(a, 0, s);
|
||||||
}
|
}
|
||||||
if (b.ndim() == 1) {
|
if (b.ndim() == 1) {
|
||||||
// Insert a singleton dim at the end
|
// Insert a singleton dim at the end
|
||||||
b = reshape(b, {-1, 1}, s);
|
b = expand_dims(b, 1, s);
|
||||||
}
|
}
|
||||||
if (a.shape(-1) != b.shape(-2)) {
|
if (a.shape(-1) != b.shape(-2)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -2608,17 +2602,21 @@ array matmul(
|
|||||||
auto out_shape = a.shape();
|
auto out_shape = a.shape();
|
||||||
out_shape.back() = b.shape(-1);
|
out_shape.back() = b.shape(-1);
|
||||||
|
|
||||||
auto p = std::make_shared<Matmul>(to_stream(s));
|
auto out = array(
|
||||||
|
std::move(out_shape),
|
||||||
|
out_type,
|
||||||
|
std::make_shared<Matmul>(to_stream(s)),
|
||||||
|
{a, b});
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
|
std::vector<int> axes;
|
||||||
auto out = array(out_shape, out_type, std::move(p), {a, b});
|
if (in_a.ndim() == 1) {
|
||||||
out_shape.erase(
|
axes.push_back(out.ndim() - 2);
|
||||||
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
|
|
||||||
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
|
|
||||||
return reshape(out, std::move(out_shape), s);
|
|
||||||
}
|
}
|
||||||
return array(std::move(out_shape), out_type, std::move(p), {a, b});
|
if (in_b.ndim() == 1) {
|
||||||
|
axes.push_back(out.ndim() - 1);
|
||||||
|
}
|
||||||
|
return axes.empty() ? out : squeeze(out, axes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array gather(
|
array gather(
|
||||||
@ -2658,15 +2656,6 @@ array gather(
|
|||||||
<< " for array with " << a.ndim() << " dimensions.";
|
<< " for array with " << a.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
for (int i = 0; i < a.ndim(); ++i) {
|
|
||||||
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
|
|
||||||
<< slice_sizes << " for array with shape " << a.shape() << ".";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Promote indices to the same type
|
// Promote indices to the same type
|
||||||
auto dtype = result_type(indices);
|
auto dtype = result_type(indices);
|
||||||
if (issubdtype(dtype, inexact)) {
|
if (issubdtype(dtype, inexact)) {
|
||||||
@ -2680,6 +2669,29 @@ array gather(
|
|||||||
idx = astype(idx, dtype, s);
|
idx = astype(idx, dtype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (a.size() == 0) {
|
||||||
|
// Empty input, either the total slice size is 0 or the indices are empty
|
||||||
|
auto total_slice = std::accumulate(
|
||||||
|
slice_sizes.begin(), slice_sizes.end(), 1, std::multiplies<int64_t>{});
|
||||||
|
auto idx_size = !inputs.empty() ? inputs[0].size() : 1;
|
||||||
|
if (idx_size != 0 && total_slice != 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[gather] If the input is empty, either the indices must be"
|
||||||
|
<< " empty or the total slice size must be 0.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-empty input, check slice sizes are valid
|
||||||
|
for (int i = 0; i < a.ndim(); ++i) {
|
||||||
|
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
|
||||||
|
<< slice_sizes << " for array with shape " << a.shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Shape out_shape;
|
Shape out_shape;
|
||||||
if (!inputs.empty()) {
|
if (!inputs.empty()) {
|
||||||
out_shape = inputs[0].shape();
|
out_shape = inputs[0].shape();
|
||||||
@ -2688,9 +2700,10 @@ array gather(
|
|||||||
|
|
||||||
inputs.insert(inputs.begin(), a);
|
inputs.insert(inputs.begin(), a);
|
||||||
return array(
|
return array(
|
||||||
out_shape,
|
std::move(out_shape),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
|
std::make_shared<Gather>(
|
||||||
|
to_stream(s), std::move(axes), std::move(slice_sizes)),
|
||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2719,7 +2732,7 @@ array take(
|
|||||||
|
|
||||||
// Make slice sizes to pass to gather
|
// Make slice sizes to pass to gather
|
||||||
Shape slice_sizes = a.shape();
|
Shape slice_sizes = a.shape();
|
||||||
slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
|
slice_sizes[axis] = 1;
|
||||||
|
|
||||||
auto out = gather(a, indices, axis, slice_sizes, s);
|
auto out = gather(a, indices, axis, slice_sizes, s);
|
||||||
|
|
||||||
@ -2736,9 +2749,7 @@ array take(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Squeeze the axis we take over
|
// Squeeze the axis we take over
|
||||||
auto out_shape = out.shape();
|
return squeeze(out, indices.ndim() + axis, s);
|
||||||
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
|
|
||||||
return reshape(out, std::move(out_shape), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
|
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
|
||||||
@ -2811,12 +2822,14 @@ array take_along_axis(
|
|||||||
}
|
}
|
||||||
std::vector<int> dims(a.ndim());
|
std::vector<int> dims(a.ndim());
|
||||||
std::iota(dims.begin(), dims.end(), 0);
|
std::iota(dims.begin(), dims.end(), 0);
|
||||||
Shape slice_sizes(a.ndim(), a.size() > 0);
|
Shape slice_sizes(a.ndim(), 1);
|
||||||
auto out = gather(a, nd_indices, dims, slice_sizes, s);
|
auto out = gather(a, nd_indices, dims, slice_sizes, s);
|
||||||
|
|
||||||
// Squeeze out the slice shape
|
// Squeeze out the slice shape
|
||||||
Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
|
for (auto& d : dims) {
|
||||||
return reshape(out, std::move(out_shape), s);
|
d += a.ndim();
|
||||||
|
}
|
||||||
|
return squeeze(out, dims, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array put_along_axis(
|
array put_along_axis(
|
||||||
@ -3935,17 +3948,20 @@ array addmm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<AddMM>(to_stream(s), alpha, beta),
|
std::make_shared<AddMM>(to_stream(s), alpha, beta),
|
||||||
{a, b, c});
|
{a, b, c});
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
std::vector<int> axes;
|
||||||
out = reshape(out, out_shape_adjusted, s);
|
if (in_a_ndim == 1) {
|
||||||
|
axes.push_back(out.ndim() - 2);
|
||||||
}
|
}
|
||||||
|
if (in_b_ndim == 1) {
|
||||||
return out;
|
axes.push_back(out.ndim() - 1);
|
||||||
|
}
|
||||||
|
return axes.empty() ? out : squeeze(out, axes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Compute matrix product with tile-level masking */
|
/** Compute matrix product with tile-level masking */
|
||||||
@ -3986,11 +4002,11 @@ array block_masked_mm(
|
|||||||
|
|
||||||
if (a.ndim() == 1) {
|
if (a.ndim() == 1) {
|
||||||
// Insert a singleton dim in the beginning
|
// Insert a singleton dim in the beginning
|
||||||
a = reshape(a, {1, -1}, s);
|
a = expand_dims(a, 0, s);
|
||||||
}
|
}
|
||||||
if (b.ndim() == 1) {
|
if (b.ndim() == 1) {
|
||||||
// Insert a singleton dim at the end
|
// Insert a singleton dim at the end
|
||||||
b = reshape(b, {-1, 1}, s);
|
b = expand_dims(b, 1, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (a.shape(-1) != b.shape(-2)) {
|
if (a.shape(-1) != b.shape(-2)) {
|
||||||
@ -4110,20 +4126,19 @@ array block_masked_mm(
|
|||||||
|
|
||||||
// Caculate array
|
// Caculate array
|
||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
|
std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
|
||||||
std::move(inputs));
|
std::move(inputs));
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
std::vector<int> axes;
|
||||||
out_shape.erase(
|
if (in_a_ndim == 1) {
|
||||||
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
|
axes.push_back(out.ndim() - 2);
|
||||||
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
|
|
||||||
out = reshape(out, out_shape, s);
|
|
||||||
}
|
}
|
||||||
|
if (in_b_ndim == 1) {
|
||||||
return out;
|
axes.push_back(out.ndim() - 1);
|
||||||
|
}
|
||||||
|
return axes.empty() ? out : squeeze(out, axes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Compute matrix product with matrix-level gather */
|
/** Compute matrix product with matrix-level gather */
|
||||||
@ -4150,11 +4165,11 @@ array gather_mm(
|
|||||||
|
|
||||||
if (a.ndim() == 1) {
|
if (a.ndim() == 1) {
|
||||||
// Insert a singleton dim in the beginning
|
// Insert a singleton dim in the beginning
|
||||||
a = reshape(a, {1, -1}, s);
|
a = expand_dims(a, 0, s);
|
||||||
}
|
}
|
||||||
if (b.ndim() == 1) {
|
if (b.ndim() == 1) {
|
||||||
// Insert a singleton dim at the end
|
// Insert a singleton dim at the end
|
||||||
b = reshape(b, {-1, 1}, s);
|
b = expand_dims(b, 1, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (a.shape(-1) != b.shape(-2)) {
|
if (a.shape(-1) != b.shape(-2)) {
|
||||||
@ -4212,20 +4227,20 @@ array gather_mm(
|
|||||||
|
|
||||||
// Caculate array
|
// Caculate array
|
||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<GatherMM>(to_stream(s)),
|
std::make_shared<GatherMM>(to_stream(s)),
|
||||||
{a, b, lhs_indices, rhs_indices});
|
{a, b, lhs_indices, rhs_indices});
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
std::vector<int> axes;
|
||||||
out_shape.erase(
|
if (in_a_ndim == 1) {
|
||||||
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
|
axes.push_back(out.ndim() - 2);
|
||||||
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
|
|
||||||
out = reshape(out, out_shape, s);
|
|
||||||
}
|
}
|
||||||
|
if (in_b_ndim == 1) {
|
||||||
return out;
|
axes.push_back(out.ndim() - 1);
|
||||||
|
}
|
||||||
|
return axes.empty() ? out : squeeze(out, axes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array diagonal(
|
array diagonal(
|
||||||
|
@ -1602,6 +1602,55 @@ std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
|
|||||||
return {{expm1(inputs[0], stream())}, axes};
|
return {{expm1(inputs[0], stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> ExpandDims::vjp(
|
||||||
|
const std::vector<array>&,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>&,
|
||||||
|
const std::vector<array>&) {
|
||||||
|
return {squeeze(cotangents[0], axes_, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> ExpandDims::jvp(
|
||||||
|
const std::vector<array>&,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>&) {
|
||||||
|
return {expand_dims(tangents[0], axes_, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> ExpandDims::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto ax = axes[0];
|
||||||
|
auto expand_axes = axes_;
|
||||||
|
for (auto& s : expand_axes) {
|
||||||
|
if (s >= axes[0]) {
|
||||||
|
s++;
|
||||||
|
} else {
|
||||||
|
ax++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExpandDims::is_equivalent(const Primitive& other) const {
|
||||||
|
const ExpandDims& a_other = static_cast<const ExpandDims&>(other);
|
||||||
|
return (axes_ == a_other.axes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape ExpandDims::output_shape(
|
||||||
|
const array& input,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto shape = input.shape();
|
||||||
|
for (auto ax : axes) {
|
||||||
|
shape.insert(shape.begin() + ax, 1);
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
return {ExpandDims::output_shape(inputs[0], axes_)};
|
||||||
|
}
|
||||||
|
|
||||||
bool FFT::is_equivalent(const Primitive& other) const {
|
bool FFT::is_equivalent(const Primitive& other) const {
|
||||||
const FFT& r_other = static_cast<const FFT&>(other);
|
const FFT& r_other = static_cast<const FFT&>(other);
|
||||||
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
||||||
@ -1846,7 +1895,7 @@ bool Gather::is_equivalent(const Primitive& other) const {
|
|||||||
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
||||||
Shape out_shape;
|
Shape out_shape;
|
||||||
if (inputs.size() > 1) {
|
if (inputs.size() > 1) {
|
||||||
out_shape = inputs[0].shape();
|
out_shape = inputs[1].shape();
|
||||||
}
|
}
|
||||||
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
|
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
|
||||||
return {std::move(out_shape)};
|
return {std::move(out_shape)};
|
||||||
@ -3847,6 +3896,57 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
|
|||||||
return {{subtract(a, b, stream())}, {to_ax}};
|
return {{subtract(a, b, stream())}, {to_ax}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> Squeeze::vjp(
|
||||||
|
const std::vector<array>&,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>&,
|
||||||
|
const std::vector<array>&) {
|
||||||
|
return {expand_dims(cotangents[0], axes_, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Squeeze::jvp(
|
||||||
|
const std::vector<array>&,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>&) {
|
||||||
|
return {squeeze(tangents[0], axes_, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Squeeze::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto ax = axes[0];
|
||||||
|
auto squeeze_axes = axes_;
|
||||||
|
for (auto& s : squeeze_axes) {
|
||||||
|
if (s >= axes[0]) {
|
||||||
|
s++;
|
||||||
|
} else {
|
||||||
|
ax--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {{squeeze(inputs[0], std::move(squeeze_axes), stream())}, {ax}};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Squeeze::is_equivalent(const Primitive& other) const {
|
||||||
|
const Squeeze& a_other = static_cast<const Squeeze&>(other);
|
||||||
|
return (axes_ == a_other.axes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape Squeeze::output_shape(const array& input, const std::vector<int>& axes) {
|
||||||
|
Shape shape;
|
||||||
|
for (int i = 0, j = 0; i < input.ndim(); ++i) {
|
||||||
|
if (j < axes.size() && i == axes[j]) {
|
||||||
|
j++;
|
||||||
|
} else {
|
||||||
|
shape.push_back(input.shape(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Squeeze::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
return {Squeeze::output_shape(inputs[0], axes_)};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Tan::vjp(
|
std::vector<array> Tan::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
|
@ -983,6 +983,28 @@ class Expm1 : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ExpandDims : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit ExpandDims(Stream stream, std::vector<int> axes)
|
||||||
|
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(ExpandDims)
|
||||||
|
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
std::vector<int> axes_;
|
||||||
|
};
|
||||||
|
|
||||||
class FFT : public UnaryPrimitive {
|
class FFT : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit FFT(
|
explicit FFT(
|
||||||
@ -1046,9 +1068,11 @@ class Gather : public UnaryPrimitive {
|
|||||||
public:
|
public:
|
||||||
explicit Gather(
|
explicit Gather(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
const std::vector<int>& axes,
|
std::vector<int> axes,
|
||||||
const std::vector<int>& slice_sizes)
|
std::vector<int> slice_sizes)
|
||||||
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
|
: UnaryPrimitive(stream),
|
||||||
|
axes_(std::move(axes)),
|
||||||
|
slice_sizes_(std::move(slice_sizes)) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -2057,6 +2081,28 @@ class Subtract : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Squeeze : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit Squeeze(Stream stream, std::vector<int> axes)
|
||||||
|
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Squeeze)
|
||||||
|
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
static Shape output_shape(const array& input, const std::vector<int>& axes);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
std::vector<int> axes_;
|
||||||
|
};
|
||||||
|
|
||||||
class Tan : public UnaryPrimitive {
|
class Tan : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
|
explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
|
||||||
|
@ -144,23 +144,23 @@ array mlx_gather_nd(
|
|||||||
int slice_index = 0;
|
int slice_index = 0;
|
||||||
for (int i = 0; i < gather_indices.size(); i++) {
|
for (int i = 0; i < gather_indices.size(); i++) {
|
||||||
if (is_slice[i]) {
|
if (is_slice[i]) {
|
||||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
Shape index_shape(max_dims + num_slices, 1);
|
||||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||||
slice_index++;
|
slice_index++;
|
||||||
} else {
|
} else {
|
||||||
std::vector<int> index_shape = gather_indices[i].shape();
|
auto index_shape = gather_indices[i].shape();
|
||||||
index_shape.insert(index_shape.end(), num_slices, 1);
|
index_shape.insert(index_shape.end(), num_slices, 1);
|
||||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// reshape them so that the int/array indices are last
|
// reshape them so that the int/array indices are last
|
||||||
for (int i = 0; i < gather_indices.size(); i++) {
|
for (int i = 0; i < gather_indices.size(); i++) {
|
||||||
if (i < num_slices) {
|
if (i < num_slices) {
|
||||||
std::vector<int> index_shape(max_dims + num_slices, 1);
|
Shape index_shape(max_dims + num_slices, 1);
|
||||||
index_shape[i] = gather_indices[i].shape(0);
|
index_shape[i] = gather_indices[i].shape(0);
|
||||||
gather_indices[i] = reshape(gather_indices[i], index_shape);
|
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -172,19 +172,11 @@ array mlx_gather_nd(
|
|||||||
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
|
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
|
||||||
src = gather(src, gather_indices, axes, slice_sizes);
|
src = gather(src, gather_indices, axes, slice_sizes);
|
||||||
|
|
||||||
// Squeeze the dims
|
// Squeeze the array index dims
|
||||||
std::vector<int> out_shape;
|
for (auto& ax : axes) {
|
||||||
out_shape.insert(
|
ax += max_dims + num_slices;
|
||||||
out_shape.end(),
|
}
|
||||||
src.shape().begin(),
|
return squeeze(src, axes);
|
||||||
src.shape().begin() + max_dims + num_slices);
|
|
||||||
out_shape.insert(
|
|
||||||
out_shape.end(),
|
|
||||||
src.shape().begin() + max_dims + num_slices + indices.size(),
|
|
||||||
src.shape().end());
|
|
||||||
src = reshape(src, out_shape);
|
|
||||||
|
|
||||||
return src;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mlx_expand_ellipsis(
|
auto mlx_expand_ellipsis(
|
||||||
|
@ -392,27 +392,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = fun(x, y=y, z=z)
|
out = fun(x, y=y, z=z)
|
||||||
self.assertEqual(out.item(), 6)
|
self.assertEqual(out.item(), 6)
|
||||||
|
|
||||||
def test_shapeless_compile(self):
|
|
||||||
y = 1
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
|
||||||
def fun(x):
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
x = mx.array([1, 2])
|
|
||||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
|
|
||||||
|
|
||||||
# The function is not recompiled, so the change
|
|
||||||
# to y should not be reflected in the output
|
|
||||||
y = 2
|
|
||||||
x = mx.array([1, 2, 3])
|
|
||||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
|
|
||||||
|
|
||||||
# Type change recompiles
|
|
||||||
x = mx.array([1.0, 2.0, 3.0])
|
|
||||||
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
|
|
||||||
fun(x, y=y, z=z)
|
|
||||||
|
|
||||||
def test_shapeless_compile(self):
|
def test_shapeless_compile(self):
|
||||||
y = 1
|
y = 1
|
||||||
|
|
||||||
@ -477,6 +456,12 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
mx.eval(cfun(x1))
|
mx.eval(cfun(x1))
|
||||||
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
return x * x.sum(-1, keepdims=False)
|
||||||
|
|
||||||
|
cfun = mx.compile(fun, shapeless=True)
|
||||||
|
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||||
|
|
||||||
def test_compile_with_constant(self):
|
def test_compile_with_constant(self):
|
||||||
# Test float
|
# Test float
|
||||||
@partial(mx.compile)
|
@partial(mx.compile)
|
||||||
@ -809,6 +794,13 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = fun(*inputs)
|
out = fun(*inputs)
|
||||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||||
|
|
||||||
|
def test_shapeless_compile_matmul(self):
|
||||||
|
a = mx.array([0.0, 1.0, 2.0])
|
||||||
|
b = mx.array([0.0, 1.0, 2.0])
|
||||||
|
|
||||||
|
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
|
||||||
|
self.assertTrue(mx.allclose(fun(a, b), a @ b))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1835,6 +1835,9 @@ TEST_CASE("test broadcast") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test gather") {
|
TEST_CASE("test gather") {
|
||||||
|
// Empty input, non-empty indices/slice
|
||||||
|
CHECK_THROWS(gather(array({}), array({1}), 0, {1}));
|
||||||
|
|
||||||
// More indices than dimensions
|
// More indices than dimensions
|
||||||
CHECK_THROWS(gather(array(0), array({1}), 0, {1}));
|
CHECK_THROWS(gather(array(0), array({1}), 0, {1}));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user