|
|
|
@ -88,7 +88,7 @@ array arange(
|
|
|
|
|
return array(
|
|
|
|
|
{size},
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Arange>(to_stream(s), start, stop, step),
|
|
|
|
|
std::make_shared<Arange>(to_stream(s), start, stop, step),
|
|
|
|
|
{});
|
|
|
|
|
}
|
|
|
|
|
array arange(
|
|
|
|
@ -163,7 +163,7 @@ array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return a;
|
|
|
|
|
}
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<AsType>(to_stream(s), dtype), {a});
|
|
|
|
|
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array as_strided(
|
|
|
|
@ -177,12 +177,12 @@ array as_strided(
|
|
|
|
|
return array(
|
|
|
|
|
shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<AsStrided>(to_stream(s), shape, strides, offset),
|
|
|
|
|
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
|
|
|
|
|
{x});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array copy(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(a.shape(), a.dtype(), std::make_unique<Copy>(to_stream(s)), {a});
|
|
|
|
|
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array full(
|
|
|
|
@ -194,7 +194,7 @@ array full(
|
|
|
|
|
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
|
|
|
|
}
|
|
|
|
|
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
|
|
|
|
|
return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
|
|
|
|
|
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array full(
|
|
|
|
@ -313,9 +313,8 @@ array reshape(
|
|
|
|
|
<< " into shape " << shape << ".";
|
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return array(
|
|
|
|
|
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
|
|
|
|
|
auto p = std::make_shared<Reshape>(to_stream(s), shape);
|
|
|
|
|
return array(std::move(shape), a.dtype(), std::move(p), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array flatten(
|
|
|
|
@ -408,6 +407,20 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return squeeze(a, axes, s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Insert one size-one dimension into the shape of `a`.
 *
 * `axis` indexes the output shape, which has `a.ndim() + 1` dimensions. A
 * negative `axis` is offset by the output dimension count before validation.
 *
 * Throws std::invalid_argument if the resolved axis falls outside
 * [0, a.ndim()]. The result is built by reshaping `a` to the new shape.
 */
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
  int out_dim = a.ndim() + 1;
  int ax = axis < 0 ? axis + out_dim : axis;
  if (ax < 0 || ax >= out_dim) {
    std::ostringstream msg;
    // Report the output rank: that is the range the axis was checked against,
    // and it is what the message text claims to describe.
    msg << "[expand_dims] Invalid axis " << axis << " for output array with "
        << out_dim << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  auto shape = a.shape();
  shape.insert(shape.begin() + ax, 1);
  // The local shape is no longer needed, so hand it over instead of copying.
  return reshape(a, std::move(shape), s);
}
|
|
|
|
|
|
|
|
|
|
array expand_dims(
|
|
|
|
|
const array& a,
|
|
|
|
|
const std::vector<int>& axes,
|
|
|
|
@ -425,7 +438,7 @@ array expand_dims(
|
|
|
|
|
ax = ax < 0 ? ax + out_ndim : ax;
|
|
|
|
|
if (ax < 0 || ax >= out_ndim) {
|
|
|
|
|
std::ostringstream msg;
|
|
|
|
|
msg << "[squeeze] Invalid axes " << ax << " for output array with "
|
|
|
|
|
msg << "[expand_dims] Invalid axes " << ax << " for output array with "
|
|
|
|
|
<< a.ndim() << " dimensions.";
|
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
|
}
|
|
|
|
@ -442,7 +455,7 @@ array expand_dims(
|
|
|
|
|
for (int i = 0; i < sorted_axes.size(); ++i) {
|
|
|
|
|
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
|
|
|
|
|
}
|
|
|
|
|
return reshape(a, out_shape, s);
|
|
|
|
|
return reshape(a, std::move(out_shape), s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Slice helper
|
|
|
|
@ -523,7 +536,7 @@ array slice(
|
|
|
|
|
return array(
|
|
|
|
|
out_shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Slice>(
|
|
|
|
|
std::make_shared<Slice>(
|
|
|
|
|
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
@ -568,7 +581,7 @@ array slice_update(
|
|
|
|
|
return array(
|
|
|
|
|
src.shape(),
|
|
|
|
|
src.dtype(),
|
|
|
|
|
std::make_unique<SliceUpdate>(
|
|
|
|
|
std::make_shared<SliceUpdate>(
|
|
|
|
|
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
|
|
|
|
{src, update_broadcasted});
|
|
|
|
|
}
|
|
|
|
@ -743,7 +756,10 @@ array concatenate(
|
|
|
|
|
auto dtype = result_type(arrays);
|
|
|
|
|
|
|
|
|
|
return array(
|
|
|
|
|
shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
|
|
|
|
|
std::move(shape),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_shared<Concatenate>(to_stream(s), ax),
|
|
|
|
|
std::move(arrays));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array concatenate(
|
|
|
|
@ -886,7 +902,7 @@ array pad(
|
|
|
|
|
return array(
|
|
|
|
|
out_shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
|
|
|
|
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
|
|
|
|
{a, astype(pad_value, a.dtype(), s)});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -975,7 +991,7 @@ array swapaxes(
|
|
|
|
|
std::vector<int> reorder(a.ndim());
|
|
|
|
|
std::iota(reorder.begin(), reorder.end(), 0);
|
|
|
|
|
std::swap(reorder[axis1], reorder[axis2]);
|
|
|
|
|
return transpose(a, reorder, s);
|
|
|
|
|
return transpose(a, std::move(reorder), s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array transpose(
|
|
|
|
@ -1000,9 +1016,9 @@ array transpose(
|
|
|
|
|
shape.push_back(a.shape()[ax]);
|
|
|
|
|
}
|
|
|
|
|
return array(
|
|
|
|
|
shape,
|
|
|
|
|
std::move(shape),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Transpose>(to_stream(s), std::move(axes)),
|
|
|
|
|
std::make_shared<Transpose>(to_stream(s), std::move(axes)),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1029,7 +1045,16 @@ array broadcast_to(
|
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
|
}
|
|
|
|
|
return array(
|
|
|
|
|
shape, a.dtype(), std::make_unique<Broadcast>(to_stream(s), shape), {a});
|
|
|
|
|
std::move(bxshape),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_shared<Broadcast>(to_stream(s), shape),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<array>
|
|
|
|
|
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
|
|
|
|
|
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<array> broadcast_arrays(
|
|
|
|
@ -1048,38 +1073,29 @@ std::vector<array> broadcast_arrays(
|
|
|
|
|
|
|
|
|
|
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
|
|
|
|
if (a.shape() != b.shape()) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(), bool_, std::make_unique<Equal>(to_stream(s)), inputs);
|
|
|
|
|
shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
|
|
|
|
if (a.shape() != b.shape()) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<NotEqual>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<NotEqual>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
|
|
|
|
if (a.shape() != b.shape()) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<Greater>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array greater_equal(
|
|
|
|
@ -1087,38 +1103,32 @@ array greater_equal(
|
|
|
|
|
const array& b,
|
|
|
|
|
StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
|
|
|
|
if (a.shape() != b.shape()) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<GreaterEqual>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<GreaterEqual>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
|
|
|
|
if (a.shape() != b.shape()) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(), bool_, std::make_unique<Less>(to_stream(s)), inputs);
|
|
|
|
|
shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
|
|
|
|
if (a.shape() != b.shape()) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<LessEqual>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<LessEqual>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array array_equal(
|
|
|
|
@ -1135,7 +1145,7 @@ array array_equal(
|
|
|
|
|
array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<Equal>(to_stream(s), equal_nan),
|
|
|
|
|
std::make_shared<Equal>(to_stream(s), equal_nan),
|
|
|
|
|
{astype(a, dtype, s), astype(b, dtype, s)}),
|
|
|
|
|
false,
|
|
|
|
|
s);
|
|
|
|
@ -1180,7 +1190,7 @@ array where(
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
out_dtype,
|
|
|
|
|
std::make_unique<Select>(to_stream(s)),
|
|
|
|
|
std::make_shared<Select>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1246,7 +1256,7 @@ array all(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
|
|
|
|
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
|
out = squeeze(out, sorted_axes, s);
|
|
|
|
@ -1280,7 +1290,7 @@ array any(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
|
|
|
|
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
|
out = squeeze(out, sorted_axes, s);
|
|
|
|
@ -1315,7 +1325,7 @@ array sum(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
|
|
|
|
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
|
out = squeeze(out, sorted_axes, s);
|
|
|
|
@ -1424,7 +1434,7 @@ array prod(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
|
|
|
|
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
|
out = squeeze(out, sorted_axes, s);
|
|
|
|
@ -1461,7 +1471,7 @@ array max(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
|
|
|
|
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
|
out = squeeze(out, sorted_axes, s);
|
|
|
|
@ -1498,7 +1508,7 @@ array min(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
|
|
|
|
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
|
out = squeeze(out, sorted_axes, s);
|
|
|
|
@ -1538,7 +1548,7 @@ array argmin(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
uint32,
|
|
|
|
|
std::make_unique<ArgReduce>(
|
|
|
|
|
std::make_shared<ArgReduce>(
|
|
|
|
|
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
@ -1571,7 +1581,7 @@ array argmax(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
uint32,
|
|
|
|
|
std::make_unique<ArgReduce>(
|
|
|
|
|
std::make_shared<ArgReduce>(
|
|
|
|
|
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
|
|
|
|
{a});
|
|
|
|
|
if (!keepdims) {
|
|
|
|
@ -1607,7 +1617,7 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), a.dtype(), std::make_unique<Sort>(to_stream(s), axis), {a});
|
|
|
|
|
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/** Returns indices that sort the flattened array. */
|
|
|
|
@ -1637,7 +1647,7 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), uint32, std::make_unique<ArgSort>(to_stream(s), axis), {a});
|
|
|
|
|
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -1677,7 +1687,7 @@ array partition(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Partition>(to_stream(s), kth_, axis_),
|
|
|
|
|
std::make_shared<Partition>(to_stream(s), kth_, axis_),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1718,7 +1728,7 @@ array argpartition(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
uint32,
|
|
|
|
|
std::make_unique<ArgPartition>(to_stream(s), kth_, axis_),
|
|
|
|
|
std::make_shared<ArgPartition>(to_stream(s), kth_, axis_),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1787,7 +1797,7 @@ array logsumexp(
|
|
|
|
|
|
|
|
|
|
array abs(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto out =
|
|
|
|
|
array(a.shape(), a.dtype(), std::make_unique<Abs>(to_stream(s)), {a});
|
|
|
|
|
array(a.shape(), a.dtype(), std::make_shared<Abs>(to_stream(s)), {a});
|
|
|
|
|
if (a.dtype() == complex64) {
|
|
|
|
|
out = astype(out, float32, s);
|
|
|
|
|
}
|
|
|
|
@ -1800,33 +1810,33 @@ array negative(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
throw std::invalid_argument(msg);
|
|
|
|
|
}
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), a.dtype(), std::make_unique<Negative>(to_stream(s)), {a});
|
|
|
|
|
a.shape(), a.dtype(), std::make_shared<Negative>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
// Unary minus for arrays: forwards to negative(a) without passing a
// stream/device argument.
array operator-(const array& a) {
|
|
|
|
|
return negative(a);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array sign(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(a.shape(), a.dtype(), std::make_unique<Sign>(to_stream(s)), {a});
|
|
|
|
|
return array(a.shape(), a.dtype(), std::make_shared<Sign>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array logical_not(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<LogicalNot>(to_stream(s)),
|
|
|
|
|
std::make_shared<LogicalNot>(to_stream(s)),
|
|
|
|
|
{astype(a, bool_, s)});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
// Broadcast arrays to a common shape
|
|
|
|
|
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
|
|
|
|
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<LogicalAnd>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<LogicalAnd>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
array operator&&(const array& a, const array& b) {
|
|
|
|
|
return logical_and(a, b);
|
|
|
|
@ -1834,13 +1844,13 @@ array operator&&(const array& a, const array& b) {
|
|
|
|
|
|
|
|
|
|
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
// Broadcast arrays to a common shape
|
|
|
|
|
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
|
|
|
|
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
bool_,
|
|
|
|
|
std::make_unique<LogicalOr>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<LogicalOr>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
array operator||(const array& a, const array& b) {
|
|
|
|
|
return logical_or(a, b);
|
|
|
|
@ -1854,9 +1864,10 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
|
|
|
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(), out_type, std::make_unique<Add>(to_stream(s)), inputs);
|
|
|
|
|
shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array operator+(const array& a, const array& b) {
|
|
|
|
@ -1866,12 +1877,13 @@ array operator+(const array& a, const array& b) {
|
|
|
|
|
array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
|
|
|
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Subtract>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<Subtract>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array operator-(const array& a, const array& b) {
|
|
|
|
@ -1881,12 +1893,13 @@ array operator-(const array& a, const array& b) {
|
|
|
|
|
array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
|
|
|
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Multiply>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<Multiply>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array operator*(const array& a, const array& b) {
|
|
|
|
@ -1895,10 +1908,11 @@ array operator*(const array& a, const array& b) {
|
|
|
|
|
|
|
|
|
|
/**
 * Divide `a` by `b` with broadcasting.
 *
 * Both operands are cast to at_least_float(promote_types(a.dtype(),
 * b.dtype())), broadcast against each other, and fed to a Divide primitive.
 * The output takes the broadcast shape and that dtype.
 */
array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
  // Both casts receive `s` unchanged, matching the other binary ops here.
  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
  auto& shape = inputs[0].shape();
  return array(
      shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
}
|
|
|
|
|
array operator/(const array& a, const array& b) {
|
|
|
|
|
return divide(a, b);
|
|
|
|
@ -1919,20 +1933,22 @@ array floor_divide(
|
|
|
|
|
return floor(divide(a, b, s), s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
|
|
|
|
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
|
|
|
|
|
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Remainder of `a` divided by `b`, with broadcasting.
 *
 * Both operands are cast to promote_types(a.dtype(), b.dtype()), broadcast
 * against each other, and fed to a Remainder primitive. The output takes the
 * broadcast shape and that dtype.
 */
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  auto dtype = promote_types(a.dtype(), b.dtype());
  // Both casts receive `s` unchanged, matching the other binary ops here.
  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
  auto& shape = inputs[0].shape();
  return array(
      shape,
      dtype,
      std::make_shared<Remainder>(to_stream(s)),
      std::move(inputs));
}
|
|
|
|
|
array operator%(const array& a, const array& b) {
|
|
|
|
|
return remainder(a, b);
|
|
|
|
@ -1944,8 +1960,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
if (is_complex(dtype)) {
|
|
|
|
|
throw std::invalid_argument("[divmod] Complex type not supported.");
|
|
|
|
|
}
|
|
|
|
|
auto inputs = broadcast_arrays(
|
|
|
|
|
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
|
|
|
|
|
return array::make_arrays(
|
|
|
|
|
{inputs[0].shape(), inputs[0].shape()},
|
|
|
|
|
{inputs[0].dtype(), inputs[0].dtype()},
|
|
|
|
@ -1956,23 +1972,25 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
|
|
|
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Maximum>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<Maximum>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
|
|
|
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Minimum>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<Minimum>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array floor(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
@ -1980,103 +1998,103 @@ array floor(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
throw std::invalid_argument("[floor] Not supported for complex64.");
|
|
|
|
|
}
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), a.dtype(), std::make_unique<Floor>(to_stream(s)), {a});
|
|
|
|
|
a.shape(), a.dtype(), std::make_shared<Floor>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Build an array backed by the Ceil primitive applied to `a`.
 *
 * The output has the same shape and dtype as `a`.
 *
 * Throws std::invalid_argument when `a` has dtype complex64.
 */
array ceil(const array& a, StreamOrDevice s /* = {} */) {
  if (a.dtype() == complex64) {
    // Name this op in the message (it previously said "[floor]").
    throw std::invalid_argument("[ceil] Not supported for complex64.");
  }
  return array(a.shape(), a.dtype(), std::make_shared<Ceil>(to_stream(s)), {a});
}
|
|
|
|
|
|
|
|
|
|
array square(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), a.dtype(), std::make_unique<Square>(to_stream(s)), {a});
|
|
|
|
|
a.shape(), a.dtype(), std::make_shared<Square>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array exp(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Exp>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Sin>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Sin>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array cos(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Cos>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Cos>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array tan(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Tan>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Tan>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array arcsin(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<ArcSin>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<ArcSin>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array arccos(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<ArcCos>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<ArcCos>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array arctan(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<ArcTan>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Sinh>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Sinh>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array cosh(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Cosh>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Cosh>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array tanh(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(a.shape(), dtype, std::make_unique<Tanh>(to_stream(s)), {input});
|
|
|
|
|
return array(a.shape(), dtype, std::make_shared<Tanh>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array arcsinh(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<ArcSinh>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<ArcSinh>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array arccosh(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<ArcCosh>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<ArcCosh>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array arctanh(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<ArcTanh>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<ArcTanh>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array log(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
@ -2085,7 +2103,7 @@ array log(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Log>(to_stream(s), Log::Base::e),
|
|
|
|
|
std::make_shared<Log>(to_stream(s), Log::Base::e),
|
|
|
|
|
{input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2095,7 +2113,7 @@ array log2(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Log>(to_stream(s), Log::Base::two),
|
|
|
|
|
std::make_shared<Log>(to_stream(s), Log::Base::two),
|
|
|
|
|
{input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2105,7 +2123,7 @@ array log10(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Log>(to_stream(s), Log::Base::ten),
|
|
|
|
|
std::make_shared<Log>(to_stream(s), Log::Base::ten),
|
|
|
|
|
{input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2113,26 +2131,27 @@ array log1p(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<Log1p>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<Log1p>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
// Make sure out type is floating point
|
|
|
|
|
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
|
|
|
|
|
auto inputs =
|
|
|
|
|
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
|
|
|
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
|
|
|
|
auto& shape = inputs[0].shape();
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(),
|
|
|
|
|
shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<LogAddExp>(to_stream(s)),
|
|
|
|
|
inputs);
|
|
|
|
|
std::make_shared<LogAddExp>(to_stream(s)),
|
|
|
|
|
std::move(inputs));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array sigmoid(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
|
auto input = astype(a, dtype, s);
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), dtype, std::make_unique<Sigmoid>(to_stream(s)), {input});
|
|
|
|
|
a.shape(), dtype, std::make_shared<Sigmoid>(to_stream(s)), {input});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array erf(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
@ -2140,7 +2159,7 @@ array erf(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Erf>(to_stream(s)),
|
|
|
|
|
std::make_shared<Erf>(to_stream(s)),
|
|
|
|
|
{astype(a, dtype, s)});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2149,19 +2168,19 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<ErfInv>(to_stream(s)),
|
|
|
|
|
std::make_shared<ErfInv>(to_stream(s)),
|
|
|
|
|
{astype(a, dtype, s)});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
|
|
|
|
|
a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
|
|
|
|
|
if (decimals == 0) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a});
|
|
|
|
|
a.shape(), a.dtype(), std::make_shared<Round>(to_stream(s)), {a});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dtype = at_least_float(a.dtype());
|
|
|
|
@ -2226,7 +2245,7 @@ array matmul(
|
|
|
|
|
auto out = array(
|
|
|
|
|
{a.shape(0), b.shape(1)},
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Matmul>(to_stream(s)),
|
|
|
|
|
std::make_shared<Matmul>(to_stream(s)),
|
|
|
|
|
{a, b});
|
|
|
|
|
return reshape(out, out_shape, s);
|
|
|
|
|
}
|
|
|
|
@ -2250,17 +2269,17 @@ array matmul(
|
|
|
|
|
auto out_shape = a.shape();
|
|
|
|
|
out_shape.back() = b.shape(-1);
|
|
|
|
|
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape, out_type, std::make_unique<Matmul>(to_stream(s)), {a, b});
|
|
|
|
|
auto p = std::make_shared<Matmul>(to_stream(s));
|
|
|
|
|
|
|
|
|
|
// Remove the possibly inserted singleton dimensions
|
|
|
|
|
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
|
|
|
|
|
auto out = array(out_shape, out_type, std::move(p), {a, b});
|
|
|
|
|
out_shape.erase(
|
|
|
|
|
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
|
|
|
|
|
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
|
|
|
|
|
out = reshape(out, out_shape, s);
|
|
|
|
|
return reshape(out, std::move(out_shape), s);
|
|
|
|
|
}
|
|
|
|
|
return out;
|
|
|
|
|
return array(std::move(out_shape), out_type, std::move(p), {a, b});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array gather(
|
|
|
|
@ -2332,7 +2351,7 @@ array gather(
|
|
|
|
|
return array(
|
|
|
|
|
out_shape,
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Gather>(to_stream(s), axes, slice_sizes),
|
|
|
|
|
std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
|
|
|
|
|
inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2516,7 +2535,7 @@ array scatter(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Scatter>(to_stream(s), mode, axes),
|
|
|
|
|
std::make_shared<Scatter>(to_stream(s), mode, axes),
|
|
|
|
|
inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2570,7 +2589,7 @@ array sqrt(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Sqrt>(to_stream(s)),
|
|
|
|
|
std::make_shared<Sqrt>(to_stream(s)),
|
|
|
|
|
{astype(a, dtype, s)});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2579,7 +2598,7 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Sqrt>(to_stream(s), true),
|
|
|
|
|
std::make_shared<Sqrt>(to_stream(s), true),
|
|
|
|
|
{astype(a, dtype, s)});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -2592,7 +2611,7 @@ array softmax(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<Softmax>(to_stream(s)),
|
|
|
|
|
std::make_shared<Softmax>(to_stream(s)),
|
|
|
|
|
{astype(a, dtype, s)});
|
|
|
|
|
} else {
|
|
|
|
|
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
|
|
|
|
@ -2614,7 +2633,7 @@ array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|
|
|
|
inputs = broadcast_arrays(inputs, s);
|
|
|
|
|
}
|
|
|
|
|
return array(
|
|
|
|
|
inputs[0].shape(), dtype, std::make_unique<Power>(to_stream(s)), inputs);
|
|
|
|
|
inputs[0].shape(), dtype, std::make_shared<Power>(to_stream(s)), inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
array cumsum(
|
|
|
|
@ -2635,7 +2654,7 @@ array cumsum(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<Scan>(
|
|
|
|
|
std::make_shared<Scan>(
|
|
|
|
|
to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
@ -2657,7 +2676,7 @@ array cumprod(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Scan>(
|
|
|
|
|
std::make_shared<Scan>(
|
|
|
|
|
to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
@ -2679,7 +2698,7 @@ array cummax(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Scan>(
|
|
|
|
|
std::make_shared<Scan>(
|
|
|
|
|
to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
@ -2701,7 +2720,7 @@ array cummin(
|
|
|
|
|
return array(
|
|
|
|
|
a.shape(),
|
|
|
|
|
a.dtype(),
|
|
|
|
|
std::make_unique<Scan>(
|
|
|
|
|
std::make_shared<Scan>(
|
|
|
|
|
to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),
|
|
|
|
|
{a});
|
|
|
|
|
}
|
|
|
|
@ -2965,7 +2984,7 @@ array conv_general(
|
|
|
|
|
return array(
|
|
|
|
|
out_shape,
|
|
|
|
|
in.dtype(),
|
|
|
|
|
std::make_unique<Convolution>(
|
|
|
|
|
std::make_shared<Convolution>(
|
|
|
|
|
to_stream(s),
|
|
|
|
|
stride,
|
|
|
|
|
padding_lo,
|
|
|
|
@ -3042,7 +3061,7 @@ array quantized_matmul(
|
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dtype = result_type({x, scales, biases});
|
|
|
|
|
auto dtype = result_type(x, scales, biases);
|
|
|
|
|
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
|
|
|
|
std::ostringstream msg;
|
|
|
|
|
msg << "[quantized_matmul] Only real floating types are supported but "
|
|
|
|
@ -3055,7 +3074,7 @@ array quantized_matmul(
|
|
|
|
|
auto out = array(
|
|
|
|
|
{x.shape(0), w_outer_dims},
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<QuantizedMatmul>(
|
|
|
|
|
std::make_shared<QuantizedMatmul>(
|
|
|
|
|
to_stream(s), group_size, bits, transpose),
|
|
|
|
|
{astype(x, dtype, s),
|
|
|
|
|
w,
|
|
|
|
@ -3065,7 +3084,7 @@ array quantized_matmul(
|
|
|
|
|
// If needed reshape x to the original batch shape
|
|
|
|
|
if (original_shape.size() != 1) {
|
|
|
|
|
original_shape.push_back(w_outer_dims);
|
|
|
|
|
out = reshape(out, original_shape, s);
|
|
|
|
|
out = reshape(out, std::move(original_shape), s);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return out;
|
|
|
|
@ -3344,7 +3363,7 @@ array addmm(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Type promotion
|
|
|
|
|
auto out_type = result_type({a, b, c});
|
|
|
|
|
auto out_type = result_type(a, b, c);
|
|
|
|
|
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
|
|
|
|
std::ostringstream msg;
|
|
|
|
|
msg << "[addmm] Only real floating point types are supported but "
|
|
|
|
@ -3373,7 +3392,7 @@ array addmm(
|
|
|
|
|
auto out = array(
|
|
|
|
|
{a.shape(0), b.shape(1)},
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<AddMM>(to_stream(s), alpha, beta),
|
|
|
|
|
std::make_shared<AddMM>(to_stream(s), alpha, beta),
|
|
|
|
|
{a, b, c});
|
|
|
|
|
return reshape(out, out_shape, s);
|
|
|
|
|
}
|
|
|
|
@ -3425,7 +3444,7 @@ array addmm(
|
|
|
|
|
auto out = array(
|
|
|
|
|
out_shape,
|
|
|
|
|
out_type,
|
|
|
|
|
std::make_unique<AddMM>(to_stream(s), alpha, beta),
|
|
|
|
|
std::make_shared<AddMM>(to_stream(s), alpha, beta),
|
|
|
|
|
{a, b, c});
|
|
|
|
|
|
|
|
|
|
// Remove the possibly inserted singleton dimensions
|
|
|
|
@ -3621,7 +3640,7 @@ array number_of_elements(
|
|
|
|
|
return stop_gradient(array(
|
|
|
|
|
std::vector<int>{},
|
|
|
|
|
dtype,
|
|
|
|
|
std::make_unique<NumberOfElements>(
|
|
|
|
|
std::make_shared<NumberOfElements>(
|
|
|
|
|
to_stream(s), std::move(axes), inverted, dtype),
|
|
|
|
|
{a}));
|
|
|
|
|
}
|
|
|
|
|