mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Reduce a little overhead (#871)
* some small overhead improvements * use result_type in rms_norm * remove release force * fix + use non-vector version * revert compile change * fix ops * a little more overhead * a little more cleanup and overhead
This commit is contained in:
parent
6ee1112f30
commit
be98f4ab6b
@ -12,16 +12,6 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
|
||||||
std::vector<size_t> strides(shape.size());
|
|
||||||
size_t cum_prod = 1;
|
|
||||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
|
||||||
strides[i] = cum_prod;
|
|
||||||
cum_prod *= shape[i];
|
|
||||||
}
|
|
||||||
return {cum_prod, strides};
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Return true if we are currently performing a function transformation in
|
/** Return true if we are currently performing a function transformation in
|
||||||
* order to keep the graph when evaluating tracer arrays. */
|
* order to keep the graph when evaluating tracer arrays. */
|
||||||
bool in_tracing() {
|
bool in_tracing() {
|
||||||
@ -171,9 +161,21 @@ void array::move_shared_buffer(array other) {
|
|||||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void array::ArrayDesc::init() {
|
||||||
|
strides.resize(shape.size());
|
||||||
|
size = 1;
|
||||||
|
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||||
|
strides[i] = size;
|
||||||
|
size *= shape[i];
|
||||||
|
}
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
is_tracer |= in.is_tracer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||||
: shape(std::move(shape)), dtype(dtype) {
|
: shape(std::move(shape)), dtype(dtype) {
|
||||||
std::tie(size, strides) = cum_prod(this->shape);
|
init();
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayDesc::ArrayDesc(
|
array::ArrayDesc::ArrayDesc(
|
||||||
@ -185,10 +187,7 @@ array::ArrayDesc::ArrayDesc(
|
|||||||
dtype(dtype),
|
dtype(dtype),
|
||||||
primitive(std::move(primitive)),
|
primitive(std::move(primitive)),
|
||||||
inputs(std::move(inputs)) {
|
inputs(std::move(inputs)) {
|
||||||
std::tie(size, strides) = cum_prod(this->shape);
|
init();
|
||||||
for (auto& in : this->inputs) {
|
|
||||||
is_tracer |= in.is_tracer();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||||
|
@ -392,6 +392,10 @@ class array {
|
|||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
std::shared_ptr<Primitive> primitive,
|
std::shared_ptr<Primitive> primitive,
|
||||||
std::vector<array> inputs);
|
std::vector<array> inputs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Initialize size, strides, and other metadata
|
||||||
|
void init();
|
||||||
};
|
};
|
||||||
|
|
||||||
// The ArrayDesc contains the details of the materialized array including the
|
// The ArrayDesc contains the details of the materialized array including the
|
||||||
|
23
mlx/fast.cpp
23
mlx/fast.cpp
@ -63,7 +63,7 @@ array rms_norm(
|
|||||||
<< " dimensions.";
|
<< " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
auto out_type = result_type({x, weight});
|
auto out_type = result_type(x, weight);
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
msg << "[rms_norm] Received unsupported type " << out_type << ".";
|
||||||
@ -88,7 +88,7 @@ array rms_norm(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<RMSNorm>(s, fallback, eps),
|
std::make_shared<RMSNorm>(s, fallback, eps),
|
||||||
{astype(x, out_type, s), astype(weight, out_type, s)});
|
{astype(x, out_type, s), astype(weight, out_type, s)});
|
||||||
}
|
}
|
||||||
return fallback({x, weight})[0];
|
return fallback({x, weight})[0];
|
||||||
@ -125,8 +125,8 @@ array layer_norm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto out_type = (weight.has_value())
|
auto out_type = (weight.has_value())
|
||||||
? ((bias.has_value()) ? result_type({x, *weight, *bias})
|
? ((bias.has_value()) ? result_type(x, *weight, *bias)
|
||||||
: result_type({x, *weight}))
|
: result_type(x, *weight))
|
||||||
: x.dtype();
|
: x.dtype();
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -170,7 +170,7 @@ array layer_norm(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<LayerNorm>(s, fallback, eps),
|
std::make_shared<LayerNorm>(s, fallback, eps),
|
||||||
{astype(x, out_type, s), passed_weight, passed_bias});
|
{astype(x, out_type, s), passed_weight, passed_bias});
|
||||||
}
|
}
|
||||||
return fallback({x, passed_weight, passed_bias})[0];
|
return fallback({x, passed_weight, passed_bias})[0];
|
||||||
@ -248,7 +248,7 @@ array rope(
|
|||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_unique<RoPE>(
|
std::make_shared<RoPE>(
|
||||||
stream, fallback, dims, traditional, base, scale, offset),
|
stream, fallback, dims, traditional, base, scale, offset),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
@ -318,7 +318,7 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto final_type = result_type({queries, keys, values});
|
auto final_type = result_type(queries, keys, values);
|
||||||
if (!is_floating_point(final_type) || is_complex(final_type)) {
|
if (!is_floating_point(final_type) || is_complex(final_type)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] Received unsupported type "
|
msg << "[scaled_dot_product_attention] Received unsupported type "
|
||||||
@ -330,9 +330,6 @@ array scaled_dot_product_attention(
|
|||||||
auto k = astype(keys, final_type, s);
|
auto k = astype(keys, final_type, s);
|
||||||
auto v = astype(values, final_type, s);
|
auto v = astype(values, final_type, s);
|
||||||
|
|
||||||
auto out_shape =
|
|
||||||
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
|
|
||||||
|
|
||||||
/* generic implementation for use cases that Metal implementation does not
|
/* generic implementation for use cases that Metal implementation does not
|
||||||
* support. For non-supported cases listed below, use MLX primitives:
|
* support. For non-supported cases listed below, use MLX primitives:
|
||||||
* * CPU implementation
|
* * CPU implementation
|
||||||
@ -381,10 +378,12 @@ array scaled_dot_product_attention(
|
|||||||
// TODO, update routing conditions post further tuning
|
// TODO, update routing conditions post further tuning
|
||||||
implementation_supports_use_case &= false;
|
implementation_supports_use_case &= false;
|
||||||
if (implementation_supports_use_case) {
|
if (implementation_supports_use_case) {
|
||||||
|
auto out_shape =
|
||||||
|
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
|
||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
std::move(out_shape),
|
||||||
final_type,
|
final_type,
|
||||||
std::make_unique<ScaledDotProductAttention>(
|
std::make_shared<ScaledDotProductAttention>(
|
||||||
stream, fallback, scale, false),
|
stream, fallback, scale, false),
|
||||||
{q, k, v});
|
{q, k, v});
|
||||||
return out;
|
return out;
|
||||||
|
@ -195,7 +195,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
auto out = array::make_arrays(
|
auto out = array::make_arrays(
|
||||||
{a.shape(), a.shape()},
|
{a.shape(), a.shape()},
|
||||||
{a.dtype(), a.dtype()},
|
{a.dtype(), a.dtype()},
|
||||||
std::make_unique<QRF>(to_stream(s)),
|
std::make_shared<QRF>(to_stream(s)),
|
||||||
{astype(a, a.dtype(), s)});
|
{astype(a, a.dtype(), s)});
|
||||||
return std::make_pair(out[0], out[1]);
|
return std::make_pair(out[0], out[1]);
|
||||||
}
|
}
|
||||||
@ -234,7 +234,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array::make_arrays(
|
return array::make_arrays(
|
||||||
{u_shape, s_shape, vt_shape},
|
{u_shape, s_shape, vt_shape},
|
||||||
{a.dtype(), a.dtype(), a.dtype()},
|
{a.dtype(), a.dtype(), a.dtype()},
|
||||||
std::make_unique<SVD>(to_stream(s)),
|
std::make_shared<SVD>(to_stream(s)),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,7 +258,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
||||||
|
353
mlx/ops.cpp
353
mlx/ops.cpp
@ -88,7 +88,7 @@ array arange(
|
|||||||
return array(
|
return array(
|
||||||
{size},
|
{size},
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Arange>(to_stream(s), start, stop, step),
|
std::make_shared<Arange>(to_stream(s), start, stop, step),
|
||||||
{});
|
{});
|
||||||
}
|
}
|
||||||
array arange(
|
array arange(
|
||||||
@ -163,7 +163,7 @@ array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<AsType>(to_stream(s), dtype), {a});
|
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array as_strided(
|
array as_strided(
|
||||||
@ -177,12 +177,12 @@ array as_strided(
|
|||||||
return array(
|
return array(
|
||||||
shape,
|
shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<AsStrided>(to_stream(s), shape, strides, offset),
|
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
array copy(const array& a, StreamOrDevice s /* = {} */) {
|
array copy(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return array(a.shape(), a.dtype(), std::make_unique<Copy>(to_stream(s)), {a});
|
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array full(
|
array full(
|
||||||
@ -194,7 +194,7 @@ array full(
|
|||||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||||
}
|
}
|
||||||
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
|
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
|
||||||
return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
|
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
|
||||||
}
|
}
|
||||||
|
|
||||||
array full(
|
array full(
|
||||||
@ -313,9 +313,8 @@ array reshape(
|
|||||||
<< " into shape " << shape << ".";
|
<< " into shape " << shape << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
auto p = std::make_shared<Reshape>(to_stream(s), shape);
|
||||||
return array(
|
return array(std::move(shape), a.dtype(), std::move(p), {a});
|
||||||
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array flatten(
|
array flatten(
|
||||||
@ -408,6 +407,20 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return squeeze(a, axes, s);
|
return squeeze(a, axes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||||
|
int out_dim = a.ndim() + 1;
|
||||||
|
int ax = axis < 0 ? axis + out_dim : axis;
|
||||||
|
if (ax < 0 || ax >= out_dim) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[expand_dims] Invalid axes " << axis << " for output array with "
|
||||||
|
<< a.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
auto shape = a.shape();
|
||||||
|
shape.insert(shape.begin() + ax, 1);
|
||||||
|
return reshape(a, std::move(shape), s);
|
||||||
|
}
|
||||||
|
|
||||||
array expand_dims(
|
array expand_dims(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes,
|
||||||
@ -425,7 +438,7 @@ array expand_dims(
|
|||||||
ax = ax < 0 ? ax + out_ndim : ax;
|
ax = ax < 0 ? ax + out_ndim : ax;
|
||||||
if (ax < 0 || ax >= out_ndim) {
|
if (ax < 0 || ax >= out_ndim) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[squeeze] Invalid axes " << ax << " for output array with "
|
msg << "[expand_dims] Invalid axes " << ax << " for output array with "
|
||||||
<< a.ndim() << " dimensions.";
|
<< a.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
@ -442,7 +455,7 @@ array expand_dims(
|
|||||||
for (int i = 0; i < sorted_axes.size(); ++i) {
|
for (int i = 0; i < sorted_axes.size(); ++i) {
|
||||||
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
|
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
|
||||||
}
|
}
|
||||||
return reshape(a, out_shape, s);
|
return reshape(a, std::move(out_shape), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slice helper
|
// Slice helper
|
||||||
@ -523,7 +536,7 @@ array slice(
|
|||||||
return array(
|
return array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Slice>(
|
std::make_shared<Slice>(
|
||||||
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@ -568,7 +581,7 @@ array slice_update(
|
|||||||
return array(
|
return array(
|
||||||
src.shape(),
|
src.shape(),
|
||||||
src.dtype(),
|
src.dtype(),
|
||||||
std::make_unique<SliceUpdate>(
|
std::make_shared<SliceUpdate>(
|
||||||
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
|
||||||
{src, update_broadcasted});
|
{src, update_broadcasted});
|
||||||
}
|
}
|
||||||
@ -743,7 +756,10 @@ array concatenate(
|
|||||||
auto dtype = result_type(arrays);
|
auto dtype = result_type(arrays);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
|
std::move(shape),
|
||||||
|
dtype,
|
||||||
|
std::make_shared<Concatenate>(to_stream(s), ax),
|
||||||
|
std::move(arrays));
|
||||||
}
|
}
|
||||||
|
|
||||||
array concatenate(
|
array concatenate(
|
||||||
@ -886,7 +902,7 @@ array pad(
|
|||||||
return array(
|
return array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
||||||
{a, astype(pad_value, a.dtype(), s)});
|
{a, astype(pad_value, a.dtype(), s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -975,7 +991,7 @@ array swapaxes(
|
|||||||
std::vector<int> reorder(a.ndim());
|
std::vector<int> reorder(a.ndim());
|
||||||
std::iota(reorder.begin(), reorder.end(), 0);
|
std::iota(reorder.begin(), reorder.end(), 0);
|
||||||
std::swap(reorder[axis1], reorder[axis2]);
|
std::swap(reorder[axis1], reorder[axis2]);
|
||||||
return transpose(a, reorder, s);
|
return transpose(a, std::move(reorder), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array transpose(
|
array transpose(
|
||||||
@ -1000,9 +1016,9 @@ array transpose(
|
|||||||
shape.push_back(a.shape()[ax]);
|
shape.push_back(a.shape()[ax]);
|
||||||
}
|
}
|
||||||
return array(
|
return array(
|
||||||
shape,
|
std::move(shape),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Transpose>(to_stream(s), std::move(axes)),
|
std::make_shared<Transpose>(to_stream(s), std::move(axes)),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1029,7 +1045,16 @@ array broadcast_to(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
return array(
|
return array(
|
||||||
shape, a.dtype(), std::make_unique<Broadcast>(to_stream(s), shape), {a});
|
std::move(bxshape),
|
||||||
|
a.dtype(),
|
||||||
|
std::make_shared<Broadcast>(to_stream(s), shape),
|
||||||
|
{a});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array>
|
||||||
|
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
|
||||||
|
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> broadcast_arrays(
|
std::vector<array> broadcast_arrays(
|
||||||
@ -1048,38 +1073,29 @@ std::vector<array> broadcast_arrays(
|
|||||||
|
|
||||||
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
if (a.shape() != b.shape()) {
|
auto& shape = inputs[0].shape();
|
||||||
inputs = broadcast_arrays(inputs, s);
|
|
||||||
}
|
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(), bool_, std::make_unique<Equal>(to_stream(s)), inputs);
|
shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
if (a.shape() != b.shape()) {
|
auto& shape = inputs[0].shape();
|
||||||
inputs = broadcast_arrays(inputs, s);
|
|
||||||
}
|
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<NotEqual>(to_stream(s)),
|
std::make_shared<NotEqual>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
if (a.shape() != b.shape()) {
|
auto& shape = inputs[0].shape();
|
||||||
inputs = broadcast_arrays(inputs, s);
|
|
||||||
}
|
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
|
||||||
bool_,
|
|
||||||
std::make_unique<Greater>(to_stream(s)),
|
|
||||||
inputs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array greater_equal(
|
array greater_equal(
|
||||||
@ -1087,38 +1103,32 @@ array greater_equal(
|
|||||||
const array& b,
|
const array& b,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
if (a.shape() != b.shape()) {
|
auto& shape = inputs[0].shape();
|
||||||
inputs = broadcast_arrays(inputs, s);
|
|
||||||
}
|
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<GreaterEqual>(to_stream(s)),
|
std::make_shared<GreaterEqual>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
if (a.shape() != b.shape()) {
|
auto& shape = inputs[0].shape();
|
||||||
inputs = broadcast_arrays(inputs, s);
|
|
||||||
}
|
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(), bool_, std::make_unique<Less>(to_stream(s)), inputs);
|
shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
if (a.shape() != b.shape()) {
|
auto& shape = inputs[0].shape();
|
||||||
inputs = broadcast_arrays(inputs, s);
|
|
||||||
}
|
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<LessEqual>(to_stream(s)),
|
std::make_shared<LessEqual>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array array_equal(
|
array array_equal(
|
||||||
@ -1135,7 +1145,7 @@ array array_equal(
|
|||||||
array(
|
array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<Equal>(to_stream(s), equal_nan),
|
std::make_shared<Equal>(to_stream(s), equal_nan),
|
||||||
{astype(a, dtype, s), astype(b, dtype, s)}),
|
{astype(a, dtype, s), astype(b, dtype, s)}),
|
||||||
false,
|
false,
|
||||||
s);
|
s);
|
||||||
@ -1180,7 +1190,7 @@ array where(
|
|||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
inputs[0].shape(),
|
||||||
out_dtype,
|
out_dtype,
|
||||||
std::make_unique<Select>(to_stream(s)),
|
std::make_shared<Select>(to_stream(s)),
|
||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1246,7 +1256,7 @@ array all(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = squeeze(out, sorted_axes, s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
@ -1280,7 +1290,7 @@ array any(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = squeeze(out, sorted_axes, s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
@ -1315,7 +1325,7 @@ array sum(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = squeeze(out, sorted_axes, s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
@ -1424,7 +1434,7 @@ array prod(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = squeeze(out, sorted_axes, s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
@ -1461,7 +1471,7 @@ array max(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = squeeze(out, sorted_axes, s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
@ -1498,7 +1508,7 @@ array min(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
out = squeeze(out, sorted_axes, s);
|
out = squeeze(out, sorted_axes, s);
|
||||||
@ -1538,7 +1548,7 @@ array argmin(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
uint32,
|
uint32,
|
||||||
std::make_unique<ArgReduce>(
|
std::make_shared<ArgReduce>(
|
||||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
@ -1571,7 +1581,7 @@ array argmax(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
uint32,
|
uint32,
|
||||||
std::make_unique<ArgReduce>(
|
std::make_shared<ArgReduce>(
|
||||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||||
{a});
|
{a});
|
||||||
if (!keepdims) {
|
if (!keepdims) {
|
||||||
@ -1607,7 +1617,7 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<Sort>(to_stream(s), axis), {a});
|
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns indices that sort the flattened array. */
|
/** Returns indices that sort the flattened array. */
|
||||||
@ -1637,7 +1647,7 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
a.shape(), uint32, std::make_unique<ArgSort>(to_stream(s), axis), {a});
|
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -1677,7 +1687,7 @@ array partition(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Partition>(to_stream(s), kth_, axis_),
|
std::make_shared<Partition>(to_stream(s), kth_, axis_),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1718,7 +1728,7 @@ array argpartition(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
uint32,
|
uint32,
|
||||||
std::make_unique<ArgPartition>(to_stream(s), kth_, axis_),
|
std::make_shared<ArgPartition>(to_stream(s), kth_, axis_),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1787,7 +1797,7 @@ array logsumexp(
|
|||||||
|
|
||||||
array abs(const array& a, StreamOrDevice s /* = {} */) {
|
array abs(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto out =
|
auto out =
|
||||||
array(a.shape(), a.dtype(), std::make_unique<Abs>(to_stream(s)), {a});
|
array(a.shape(), a.dtype(), std::make_shared<Abs>(to_stream(s)), {a});
|
||||||
if (a.dtype() == complex64) {
|
if (a.dtype() == complex64) {
|
||||||
out = astype(out, float32, s);
|
out = astype(out, float32, s);
|
||||||
}
|
}
|
||||||
@ -1800,33 +1810,33 @@ array negative(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
throw std::invalid_argument(msg);
|
throw std::invalid_argument(msg);
|
||||||
}
|
}
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<Negative>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_shared<Negative>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
array operator-(const array& a) {
|
array operator-(const array& a) {
|
||||||
return negative(a);
|
return negative(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
array sign(const array& a, StreamOrDevice s /* = {} */) {
|
array sign(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return array(a.shape(), a.dtype(), std::make_unique<Sign>(to_stream(s)), {a});
|
return array(a.shape(), a.dtype(), std::make_shared<Sign>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array logical_not(const array& a, StreamOrDevice s /* = {} */) {
|
array logical_not(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<LogicalNot>(to_stream(s)),
|
std::make_shared<LogicalNot>(to_stream(s)),
|
||||||
{astype(a, bool_, s)});
|
{astype(a, bool_, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
// Broadcast arrays to a common shape
|
// Broadcast arrays to a common shape
|
||||||
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<LogicalAnd>(to_stream(s)),
|
std::make_shared<LogicalAnd>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
array operator&&(const array& a, const array& b) {
|
array operator&&(const array& a, const array& b) {
|
||||||
return logical_and(a, b);
|
return logical_and(a, b);
|
||||||
@ -1834,13 +1844,13 @@ array operator&&(const array& a, const array& b) {
|
|||||||
|
|
||||||
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
// Broadcast arrays to a common shape
|
// Broadcast arrays to a common shape
|
||||||
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
bool_,
|
bool_,
|
||||||
std::make_unique<LogicalOr>(to_stream(s)),
|
std::make_shared<LogicalOr>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
array operator||(const array& a, const array& b) {
|
array operator||(const array& a, const array& b) {
|
||||||
return logical_or(a, b);
|
return logical_or(a, b);
|
||||||
@ -1854,9 +1864,10 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs =
|
auto inputs =
|
||||||
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(), out_type, std::make_unique<Add>(to_stream(s)), inputs);
|
shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array operator+(const array& a, const array& b) {
|
array operator+(const array& a, const array& b) {
|
||||||
@ -1866,12 +1877,13 @@ array operator+(const array& a, const array& b) {
|
|||||||
array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs =
|
auto inputs =
|
||||||
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Subtract>(to_stream(s)),
|
std::make_shared<Subtract>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array operator-(const array& a, const array& b) {
|
array operator-(const array& a, const array& b) {
|
||||||
@ -1881,12 +1893,13 @@ array operator-(const array& a, const array& b) {
|
|||||||
array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs =
|
auto inputs =
|
||||||
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Multiply>(to_stream(s)),
|
std::make_shared<Multiply>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array operator*(const array& a, const array& b) {
|
array operator*(const array& a, const array& b) {
|
||||||
@ -1895,10 +1908,11 @@ array operator*(const array& a, const array& b) {
|
|||||||
|
|
||||||
array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
|
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
|
||||||
auto inputs = broadcast_arrays(
|
auto inputs =
|
||||||
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
|
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
|
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
|
||||||
}
|
}
|
||||||
array operator/(const array& a, const array& b) {
|
array operator/(const array& a, const array& b) {
|
||||||
return divide(a, b);
|
return divide(a, b);
|
||||||
@ -1919,20 +1933,22 @@ array floor_divide(
|
|||||||
return floor(divide(a, b, s), s);
|
return floor(divide(a, b, s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
|
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
|
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = promote_types(a.dtype(), b.dtype());
|
auto dtype = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs = broadcast_arrays(
|
auto inputs =
|
||||||
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
|
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Remainder>(to_stream(s)),
|
std::make_shared<Remainder>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
array operator%(const array& a, const array& b) {
|
array operator%(const array& a, const array& b) {
|
||||||
return remainder(a, b);
|
return remainder(a, b);
|
||||||
@ -1944,8 +1960,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
if (is_complex(dtype)) {
|
if (is_complex(dtype)) {
|
||||||
throw std::invalid_argument("[divmod] Complex type not supported.");
|
throw std::invalid_argument("[divmod] Complex type not supported.");
|
||||||
}
|
}
|
||||||
auto inputs = broadcast_arrays(
|
auto inputs =
|
||||||
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
|
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
|
||||||
return array::make_arrays(
|
return array::make_arrays(
|
||||||
{inputs[0].shape(), inputs[0].shape()},
|
{inputs[0].shape(), inputs[0].shape()},
|
||||||
{inputs[0].dtype(), inputs[0].dtype()},
|
{inputs[0].dtype(), inputs[0].dtype()},
|
||||||
@ -1956,23 +1972,25 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs =
|
auto inputs =
|
||||||
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Maximum>(to_stream(s)),
|
std::make_shared<Maximum>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
auto inputs =
|
auto inputs =
|
||||||
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Minimum>(to_stream(s)),
|
std::make_shared<Minimum>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array floor(const array& a, StreamOrDevice s /* = {} */) {
|
array floor(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
@ -1980,103 +1998,103 @@ array floor(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
throw std::invalid_argument("[floor] Not supported for complex64.");
|
throw std::invalid_argument("[floor] Not supported for complex64.");
|
||||||
}
|
}
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<Floor>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_shared<Floor>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array ceil(const array& a, StreamOrDevice s /* = {} */) {
|
array ceil(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (a.dtype() == complex64) {
|
if (a.dtype() == complex64) {
|
||||||
throw std::invalid_argument("[floor] Not supported for complex64.");
|
throw std::invalid_argument("[floor] Not supported for complex64.");
|
||||||
}
|
}
|
||||||
return array(a.shape(), a.dtype(), std::make_unique<Ceil>(to_stream(s)), {a});
|
return array(a.shape(), a.dtype(), std::make_shared<Ceil>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array square(const array& a, StreamOrDevice s /* = {} */) {
|
array square(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<Square>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_shared<Square>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array exp(const array& a, StreamOrDevice s /* = {} */) {
|
array exp(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Exp>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
array sin(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Sin>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Sin>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array cos(const array& a, StreamOrDevice s /* = {} */) {
|
array cos(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Cos>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Cos>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array tan(const array& a, StreamOrDevice s /* = {} */) {
|
array tan(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Tan>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Tan>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array arcsin(const array& a, StreamOrDevice s /* = {} */) {
|
array arcsin(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<ArcSin>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcSin>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array arccos(const array& a, StreamOrDevice s /* = {} */) {
|
array arccos(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<ArcCos>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcCos>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array arctan(const array& a, StreamOrDevice s /* = {} */) {
|
array arctan(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<ArcTan>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
array sinh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Sinh>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Sinh>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array cosh(const array& a, StreamOrDevice s /* = {} */) {
|
array cosh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Cosh>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Cosh>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array tanh(const array& a, StreamOrDevice s /* = {} */) {
|
array tanh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(a.shape(), dtype, std::make_unique<Tanh>(to_stream(s)), {input});
|
return array(a.shape(), dtype, std::make_shared<Tanh>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array arcsinh(const array& a, StreamOrDevice s /* = {} */) {
|
array arcsinh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<ArcSinh>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcSinh>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array arccosh(const array& a, StreamOrDevice s /* = {} */) {
|
array arccosh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<ArcCosh>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcCosh>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array arctanh(const array& a, StreamOrDevice s /* = {} */) {
|
array arctanh(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<ArcTanh>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<ArcTanh>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array log(const array& a, StreamOrDevice s /* = {} */) {
|
array log(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
@ -2085,7 +2103,7 @@ array log(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Log>(to_stream(s), Log::Base::e),
|
std::make_shared<Log>(to_stream(s), Log::Base::e),
|
||||||
{input});
|
{input});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2095,7 +2113,7 @@ array log2(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Log>(to_stream(s), Log::Base::two),
|
std::make_shared<Log>(to_stream(s), Log::Base::two),
|
||||||
{input});
|
{input});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2105,7 +2123,7 @@ array log10(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Log>(to_stream(s), Log::Base::ten),
|
std::make_shared<Log>(to_stream(s), Log::Base::ten),
|
||||||
{input});
|
{input});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2113,26 +2131,27 @@ array log1p(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<Log1p>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<Log1p>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
// Make sure out type is floating point
|
// Make sure out type is floating point
|
||||||
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
|
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
|
||||||
auto inputs =
|
auto inputs =
|
||||||
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
|
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||||
|
auto& shape = inputs[0].shape();
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(),
|
shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<LogAddExp>(to_stream(s)),
|
std::make_shared<LogAddExp>(to_stream(s)),
|
||||||
inputs);
|
std::move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
array sigmoid(const array& a, StreamOrDevice s /* = {} */) {
|
array sigmoid(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto input = astype(a, dtype, s);
|
auto input = astype(a, dtype, s);
|
||||||
return array(
|
return array(
|
||||||
a.shape(), dtype, std::make_unique<Sigmoid>(to_stream(s)), {input});
|
a.shape(), dtype, std::make_shared<Sigmoid>(to_stream(s)), {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
array erf(const array& a, StreamOrDevice s /* = {} */) {
|
array erf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
@ -2140,7 +2159,7 @@ array erf(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Erf>(to_stream(s)),
|
std::make_shared<Erf>(to_stream(s)),
|
||||||
{astype(a, dtype, s)});
|
{astype(a, dtype, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2149,19 +2168,19 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<ErfInv>(to_stream(s)),
|
std::make_shared<ErfInv>(to_stream(s)),
|
||||||
{astype(a, dtype, s)});
|
{astype(a, dtype, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
|
array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
|
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
|
||||||
if (decimals == 0) {
|
if (decimals == 0) {
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_shared<Round>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
@ -2226,7 +2245,7 @@ array matmul(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
{a.shape(0), b.shape(1)},
|
{a.shape(0), b.shape(1)},
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Matmul>(to_stream(s)),
|
std::make_shared<Matmul>(to_stream(s)),
|
||||||
{a, b});
|
{a, b});
|
||||||
return reshape(out, out_shape, s);
|
return reshape(out, out_shape, s);
|
||||||
}
|
}
|
||||||
@ -2250,17 +2269,17 @@ array matmul(
|
|||||||
auto out_shape = a.shape();
|
auto out_shape = a.shape();
|
||||||
out_shape.back() = b.shape(-1);
|
out_shape.back() = b.shape(-1);
|
||||||
|
|
||||||
auto out = array(
|
auto p = std::make_shared<Matmul>(to_stream(s));
|
||||||
out_shape, out_type, std::make_unique<Matmul>(to_stream(s)), {a, b});
|
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
|
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
|
||||||
|
auto out = array(out_shape, out_type, std::move(p), {a, b});
|
||||||
out_shape.erase(
|
out_shape.erase(
|
||||||
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
|
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
|
||||||
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
|
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
|
||||||
out = reshape(out, out_shape, s);
|
return reshape(out, std::move(out_shape), s);
|
||||||
}
|
}
|
||||||
return out;
|
return array(std::move(out_shape), out_type, std::move(p), {a, b});
|
||||||
}
|
}
|
||||||
|
|
||||||
array gather(
|
array gather(
|
||||||
@ -2332,7 +2351,7 @@ array gather(
|
|||||||
return array(
|
return array(
|
||||||
out_shape,
|
out_shape,
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Gather>(to_stream(s), axes, slice_sizes),
|
std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
|
||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2516,7 +2535,7 @@ array scatter(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Scatter>(to_stream(s), mode, axes),
|
std::make_shared<Scatter>(to_stream(s), mode, axes),
|
||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2570,7 +2589,7 @@ array sqrt(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Sqrt>(to_stream(s)),
|
std::make_shared<Sqrt>(to_stream(s)),
|
||||||
{astype(a, dtype, s)});
|
{astype(a, dtype, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2579,7 +2598,7 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Sqrt>(to_stream(s), true),
|
std::make_shared<Sqrt>(to_stream(s), true),
|
||||||
{astype(a, dtype, s)});
|
{astype(a, dtype, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2592,7 +2611,7 @@ array softmax(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<Softmax>(to_stream(s)),
|
std::make_shared<Softmax>(to_stream(s)),
|
||||||
{astype(a, dtype, s)});
|
{astype(a, dtype, s)});
|
||||||
} else {
|
} else {
|
||||||
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
|
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
|
||||||
@ -2614,7 +2633,7 @@ array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
inputs = broadcast_arrays(inputs, s);
|
inputs = broadcast_arrays(inputs, s);
|
||||||
}
|
}
|
||||||
return array(
|
return array(
|
||||||
inputs[0].shape(), dtype, std::make_unique<Power>(to_stream(s)), inputs);
|
inputs[0].shape(), dtype, std::make_shared<Power>(to_stream(s)), inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
array cumsum(
|
array cumsum(
|
||||||
@ -2635,7 +2654,7 @@ array cumsum(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<Scan>(
|
std::make_shared<Scan>(
|
||||||
to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
|
to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@ -2657,7 +2676,7 @@ array cumprod(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Scan>(
|
std::make_shared<Scan>(
|
||||||
to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
|
to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@ -2679,7 +2698,7 @@ array cummax(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Scan>(
|
std::make_shared<Scan>(
|
||||||
to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),
|
to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@ -2701,7 +2720,7 @@ array cummin(
|
|||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::make_unique<Scan>(
|
std::make_shared<Scan>(
|
||||||
to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),
|
to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),
|
||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
@ -2965,7 +2984,7 @@ array conv_general(
|
|||||||
return array(
|
return array(
|
||||||
out_shape,
|
out_shape,
|
||||||
in.dtype(),
|
in.dtype(),
|
||||||
std::make_unique<Convolution>(
|
std::make_shared<Convolution>(
|
||||||
to_stream(s),
|
to_stream(s),
|
||||||
stride,
|
stride,
|
||||||
padding_lo,
|
padding_lo,
|
||||||
@ -3042,7 +3061,7 @@ array quantized_matmul(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto dtype = result_type({x, scales, biases});
|
auto dtype = result_type(x, scales, biases);
|
||||||
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
if (!is_floating_point(dtype) || is_complex(dtype)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[quantized_matmul] Only real floating types are supported but "
|
msg << "[quantized_matmul] Only real floating types are supported but "
|
||||||
@ -3055,7 +3074,7 @@ array quantized_matmul(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
{x.shape(0), w_outer_dims},
|
{x.shape(0), w_outer_dims},
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<QuantizedMatmul>(
|
std::make_shared<QuantizedMatmul>(
|
||||||
to_stream(s), group_size, bits, transpose),
|
to_stream(s), group_size, bits, transpose),
|
||||||
{astype(x, dtype, s),
|
{astype(x, dtype, s),
|
||||||
w,
|
w,
|
||||||
@ -3065,7 +3084,7 @@ array quantized_matmul(
|
|||||||
// If needed reshape x to the original batch shape
|
// If needed reshape x to the original batch shape
|
||||||
if (original_shape.size() != 1) {
|
if (original_shape.size() != 1) {
|
||||||
original_shape.push_back(w_outer_dims);
|
original_shape.push_back(w_outer_dims);
|
||||||
out = reshape(out, original_shape, s);
|
out = reshape(out, std::move(original_shape), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
@ -3344,7 +3363,7 @@ array addmm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Type promotion
|
// Type promotion
|
||||||
auto out_type = result_type({a, b, c});
|
auto out_type = result_type(a, b, c);
|
||||||
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
if (!is_floating_point(out_type) || is_complex(out_type)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[addmm] Only real floating point types are supported but "
|
msg << "[addmm] Only real floating point types are supported but "
|
||||||
@ -3373,7 +3392,7 @@ array addmm(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
{a.shape(0), b.shape(1)},
|
{a.shape(0), b.shape(1)},
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<AddMM>(to_stream(s), alpha, beta),
|
std::make_shared<AddMM>(to_stream(s), alpha, beta),
|
||||||
{a, b, c});
|
{a, b, c});
|
||||||
return reshape(out, out_shape, s);
|
return reshape(out, out_shape, s);
|
||||||
}
|
}
|
||||||
@ -3425,7 +3444,7 @@ array addmm(
|
|||||||
auto out = array(
|
auto out = array(
|
||||||
out_shape,
|
out_shape,
|
||||||
out_type,
|
out_type,
|
||||||
std::make_unique<AddMM>(to_stream(s), alpha, beta),
|
std::make_shared<AddMM>(to_stream(s), alpha, beta),
|
||||||
{a, b, c});
|
{a, b, c});
|
||||||
|
|
||||||
// Remove the possibly inserted singleton dimensions
|
// Remove the possibly inserted singleton dimensions
|
||||||
@ -3621,7 +3640,7 @@ array number_of_elements(
|
|||||||
return stop_gradient(array(
|
return stop_gradient(array(
|
||||||
std::vector<int>{},
|
std::vector<int>{},
|
||||||
dtype,
|
dtype,
|
||||||
std::make_unique<NumberOfElements>(
|
std::make_shared<NumberOfElements>(
|
||||||
to_stream(s), std::move(axes), inverted, dtype),
|
to_stream(s), std::move(axes), inverted, dtype),
|
||||||
{a}));
|
{a}));
|
||||||
}
|
}
|
||||||
|
@ -158,9 +158,7 @@ array expand_dims(
|
|||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Add a singleton dimension at the given axis. */
|
/** Add a singleton dimension at the given axis. */
|
||||||
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
|
array expand_dims(const array& a, int axis, StreamOrDevice s = {});
|
||||||
return expand_dims(a, std::vector<int>{axis}, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Slice an array. */
|
/** Slice an array. */
|
||||||
array slice(
|
array slice(
|
||||||
|
@ -66,7 +66,7 @@ array bits(
|
|||||||
return array(
|
return array(
|
||||||
shape,
|
shape,
|
||||||
get_dtype(),
|
get_dtype(),
|
||||||
std::make_unique<RandomBits>(to_stream(s), shape, width),
|
std::make_shared<RandomBits>(to_stream(s), shape, width),
|
||||||
{key});
|
{key});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +54,12 @@ struct PrintFormatter {
|
|||||||
extern PrintFormatter global_formatter;
|
extern PrintFormatter global_formatter;
|
||||||
|
|
||||||
/** The type from promoting the arrays' types with one another. */
|
/** The type from promoting the arrays' types with one another. */
|
||||||
|
inline Dtype result_type(const array& a, const array& b) {
|
||||||
|
return promote_types(a.dtype(), b.dtype());
|
||||||
|
}
|
||||||
|
inline Dtype result_type(const array& a, const array& b, const array& c) {
|
||||||
|
return promote_types(result_type(a, b), c.dtype());
|
||||||
|
}
|
||||||
Dtype result_type(const std::vector<array>& arrays);
|
Dtype result_type(const std::vector<array>& arrays);
|
||||||
|
|
||||||
std::vector<int> broadcast_shapes(
|
std::vector<int> broadcast_shapes(
|
||||||
|
@ -131,8 +131,8 @@ class Module(dict):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
if key in self:
|
if (value := self.get(key, None)) is not None:
|
||||||
return self[key]
|
return value
|
||||||
else:
|
else:
|
||||||
super(Module, self).__getattribute__(key)
|
super(Module, self).__getattribute__(key)
|
||||||
|
|
||||||
|
@ -64,9 +64,9 @@ class Linear(Module):
|
|||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
x = mx.addmm(self.bias, x, self.weight.T)
|
x = mx.addmm(self["bias"], x, self["weight"].T)
|
||||||
else:
|
else:
|
||||||
x = x @ self.weight.T
|
x = x @ self["weight"].T
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,7 +140,7 @@ class RMSNorm(Module):
|
|||||||
return f"{self.weight.shape[0]}, eps={self.eps}"
|
return f"{self.weight.shape[0]}, eps={self.eps}"
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
return mx.fast.rms_norm(x, self["weight"], self.eps)
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm(Module):
|
class GroupNorm(Module):
|
||||||
|
@ -81,15 +81,15 @@ class QuantizedLinear(Module):
|
|||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = mx.quantized_matmul(
|
x = mx.quantized_matmul(
|
||||||
x,
|
x,
|
||||||
self.weight,
|
self["weight"],
|
||||||
scales=self.scales,
|
scales=self["scales"],
|
||||||
biases=self.biases,
|
biases=self["biases"],
|
||||||
transpose=True,
|
transpose=True,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bits=self.bits,
|
bits=self.bits,
|
||||||
)
|
)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
x = x + self.bias
|
x = x + self["bias"]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -17,12 +17,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"rms_norm",
|
"rms_norm",
|
||||||
[](const array& x,
|
&fast::rms_norm,
|
||||||
const array& weight,
|
|
||||||
float eps,
|
|
||||||
const StreamOrDevice& s /* = {} */) {
|
|
||||||
return fast::rms_norm(x, weight, eps, s);
|
|
||||||
},
|
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"weight"_a,
|
"weight"_a,
|
||||||
"eps"_a,
|
"eps"_a,
|
||||||
@ -48,13 +43,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"layer_norm",
|
"layer_norm",
|
||||||
[](const array& x,
|
&fast::layer_norm,
|
||||||
const std::optional<array>& weight,
|
|
||||||
const std::optional<array>& bias,
|
|
||||||
float eps,
|
|
||||||
const StreamOrDevice& s /* = {} */) {
|
|
||||||
return fast::layer_norm(x, weight, bias, eps, s);
|
|
||||||
},
|
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"weight"_a.none(),
|
"weight"_a.none(),
|
||||||
"bias"_a.none(),
|
"bias"_a.none(),
|
||||||
@ -84,15 +73,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"rope",
|
"rope",
|
||||||
[](const array& a,
|
&fast::rope,
|
||||||
int dims,
|
|
||||||
bool traditional,
|
|
||||||
float base,
|
|
||||||
float scale,
|
|
||||||
int offset,
|
|
||||||
const StreamOrDevice& s /* = {} */) {
|
|
||||||
return fast::rope(a, dims, traditional, base, scale, offset, s);
|
|
||||||
},
|
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"dims"_a,
|
"dims"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -123,14 +104,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"scaled_dot_product_attention",
|
"scaled_dot_product_attention",
|
||||||
[](const array& q,
|
&fast::scaled_dot_product_attention,
|
||||||
const array& k,
|
|
||||||
const array& v,
|
|
||||||
const float scale,
|
|
||||||
const std::optional<array>& mask,
|
|
||||||
const StreamOrDevice& s) {
|
|
||||||
return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
|
|
||||||
},
|
|
||||||
"q"_a,
|
"q"_a,
|
||||||
"k"_a,
|
"k"_a,
|
||||||
"v"_a,
|
"v"_a,
|
||||||
|
Loading…
Reference in New Issue
Block a user