Shape and Strides 1 / N (#1645)

* shape and stride type def

* more shape
This commit is contained in:
Awni Hannun 2024-12-05 12:53:43 -08:00 committed by GitHub
parent c5b0928c1f
commit fc88fd9097
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 178 additions and 242 deletions

View File

@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
} }
array::array( array::array(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs) std::vector<array> inputs)
@ -42,7 +42,7 @@ array::array(
std::move(inputs))) {} std::move(inputs))) {}
std::vector<array> array::make_arrays( std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes, std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes, const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) { const std::vector<array>& inputs) {
@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)
} }
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array( array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter); set_data(data, deleter);
} }
@ -126,7 +122,7 @@ bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph(); return array_desc_->is_tracer && in_tracing() || retain_graph();
} }
void array::set_data(allocator::Buffer buffer, deleter_t d) { void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr(); array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size(); array_desc_->data_size = size();
@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data( void array::set_data(
allocator::Buffer buffer, allocator::Buffer buffer,
size_t data_size, size_t data_size,
std::vector<size_t> strides, Strides strides,
Flags flags, Flags flags,
deleter_t d) { Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr(); array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
@ -151,7 +147,7 @@ void array::set_data(
void array::copy_shared_buffer( void array::copy_shared_buffer(
const array& other, const array& other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset /* = 0 */) { size_t offset /* = 0 */) {
@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {
void array::move_shared_buffer( void array::move_shared_buffer(
array other, array other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset /* = 0 */) { size_t offset /* = 0 */) {
@ -237,13 +233,13 @@ void array::ArrayDesc::init() {
} }
} }
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype) array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) { : shape(std::move(shape)), dtype(dtype), status(Status::available) {
init(); init();
} }
array::ArrayDesc::ArrayDesc( array::ArrayDesc::ArrayDesc(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs) std::vector<array> inputs)

View File

@ -15,7 +15,10 @@ namespace mlx::core {
// Forward declaration // Forward declaration
class Primitive; class Primitive;
using deleter_t = std::function<void(allocator::Buffer)>;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<size_t>;
class array { class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc /* An array is really a node in a graph. It contains a shared ArrayDesc
@ -33,7 +36,7 @@ class array {
template <typename It> template <typename It>
array( array(
It data, It data,
std::vector<int> shape, Shape shape,
Dtype dtype = Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>()); TypeToDtype<typename std::iterator_traits<It>::value_type>());
@ -49,15 +52,15 @@ class array {
template <typename T> template <typename T>
array( array(
std::initializer_list<T> data, std::initializer_list<T> data,
std::vector<int> shape, Shape shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */ /* Build an array from a buffer */
array( array(
allocator::Buffer data, allocator::Buffer data,
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
deleter_t deleter = allocator::free); Deleter deleter = allocator::free);
/** Assignment to rvalue does not compile. */ /** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete; array& operator=(const array& other) && = delete;
@ -96,7 +99,7 @@ class array {
} }
/** The shape of the array as a vector of integers. */ /** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const { const Shape& shape() const {
return array_desc_->shape; return array_desc_->shape;
} }
@ -105,12 +108,12 @@ class array {
* *
* This function supports negative indexing and provides * This function supports negative indexing and provides
* bounds checking. */ * bounds checking. */
int shape(int dim) const { auto shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim); return shape().at(dim < 0 ? dim + ndim() : dim);
} }
/** The strides of the array. */ /** The strides of the array. */
const std::vector<size_t>& strides() const { const Strides& strides() const {
return array_desc_->strides; return array_desc_->strides;
} }
@ -119,7 +122,7 @@ class array {
* *
* This function supports negative indexing and provides * This function supports negative indexing and provides
* bounds checking. */ * bounds checking. */
size_t strides(int dim) const { auto strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim); return strides().at(dim < 0 ? dim + ndim() : dim);
} }
@ -184,13 +187,13 @@ class array {
*/ */
array( array(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs); std::vector<array> inputs);
static std::vector<array> make_arrays( static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes, std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes, const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs); const std::vector<array>& inputs);
@ -207,8 +210,8 @@ class array {
struct Data { struct Data {
allocator::Buffer buffer; allocator::Buffer buffer;
deleter_t d; Deleter d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free) Data(allocator::Buffer buffer, Deleter d = allocator::free)
: buffer(buffer), d(d) {} : buffer(buffer), d(d) {}
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
@ -397,18 +400,18 @@ class array {
// Check if the array is a tracer array // Check if the array is a tracer array
bool is_tracer() const; bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free); void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data( void set_data(
allocator::Buffer buffer, allocator::Buffer buffer,
size_t data_size, size_t data_size,
std::vector<size_t> strides, Strides strides,
Flags flags, Flags flags,
deleter_t d = allocator::free); Deleter d = allocator::free);
void copy_shared_buffer( void copy_shared_buffer(
const array& other, const array& other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset = 0); size_t offset = 0);
@ -417,7 +420,7 @@ class array {
void move_shared_buffer( void move_shared_buffer(
array other, array other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset = 0); size_t offset = 0);
@ -436,8 +439,8 @@ class array {
void init(const It src); void init(const It src);
struct ArrayDesc { struct ArrayDesc {
std::vector<int> shape; Shape shape;
std::vector<size_t> strides; Strides strides;
size_t size; size_t size;
Dtype dtype; Dtype dtype;
std::shared_ptr<Primitive> primitive; std::shared_ptr<Primitive> primitive;
@ -471,10 +474,10 @@ class array {
// The arrays position in the output list // The arrays position in the output list
uint32_t position{0}; uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype); explicit ArrayDesc(Shape shape, Dtype dtype);
explicit ArrayDesc( explicit ArrayDesc(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs); std::vector<array> inputs);
@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It> template <typename It>
array::array( array::array(
It data, It data,
std::vector<int> shape, Shape shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) : Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data); init(data);
@ -521,7 +524,7 @@ array::array(
template <typename T> template <typename T>
array::array( array::array(
std::initializer_list<T> data, std::initializer_list<T> data,
std::vector<int> shape, Shape shape,
Dtype dtype /* = TypeToDtype<T>() */) Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) { if (data.size() != size()) {

View File

@ -16,10 +16,9 @@ namespace mlx::core {
namespace { namespace {
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool> std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
compute_reduce_shape(
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& shape) { const Shape& shape) {
bool is_noop = true; bool is_noop = true;
std::set<int> axes_set; std::set<int> axes_set;
auto ndim = shape.size(); auto ndim = shape.size();
@ -36,8 +35,8 @@ compute_reduce_shape(
if (axes_set.size() != axes.size()) { if (axes_set.size() != axes.size()) {
throw std::invalid_argument("Duplicate axes detected in reduction."); throw std::invalid_argument("Duplicate axes detected in reduction.");
} }
std::vector<int> out_shape; Shape out_shape;
std::vector<int> squeezed_shape; Shape squeezed_shape;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) { if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]); out_shape.push_back(shape[i]);
@ -63,7 +62,7 @@ array indices_or_default(
return indices.value(); return indices.value();
} }
std::vector<int> shape(x.shape().begin(), x.shape().end() - 2); Shape shape(x.shape().begin(), x.shape().end() - 2);
int total = int total =
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
return reshape(arange(total, uint32, s), shape, s); return reshape(arange(total, uint32, s), shape, s);
@ -254,8 +253,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
array as_strided( array as_strided(
array a, array a,
std::vector<int> shape, Shape shape,
std::vector<size_t> strides, Strides strides,
size_t offset, size_t offset,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto copied_shape = shape; // |shape| will be moved auto copied_shape = shape; // |shape| will be moved
@ -279,12 +278,8 @@ array copy(array a, StreamOrDevice s /* = {} */) {
{std::move(a)}); {std::move(a)});
} }
array full( array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
std::vector<int> shape, if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
array vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed."); throw std::invalid_argument("[full] Negative dimensions not allowed.");
} }
auto copied_shape = shape; // |shape| will be moved auto copied_shape = shape; // |shape| will be moved
@ -295,15 +290,12 @@ array full(
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)}); {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
} }
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) { array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
auto dtype = vals.dtype(); // |vals| will be moved auto dtype = vals.dtype(); // |vals| will be moved
return full(std::move(shape), std::move(vals), dtype, to_stream(s)); return full(std::move(shape), std::move(vals), dtype, to_stream(s));
} }
array zeros( array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
const std::vector<int>& shape,
Dtype dtype,
StreamOrDevice s /* = {} */) {
return full(shape, array(0, dtype), to_stream(s)); return full(shape, array(0, dtype), to_stream(s));
} }
@ -311,10 +303,7 @@ array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
return zeros(a.shape(), a.dtype(), to_stream(s)); return zeros(a.shape(), a.dtype(), to_stream(s));
} }
array ones( array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
const std::vector<int>& shape,
Dtype dtype,
StreamOrDevice s /* = {} */) {
return full(shape, array(1, dtype), to_stream(s)); return full(shape, array(1, dtype), to_stream(s));
} }
@ -368,10 +357,7 @@ array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
return where(mask, zeros_like(x, s), x, s); return where(mask, zeros_like(x, s), x, s);
} }
array reshape( array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
const array& a,
std::vector<int> shape,
StreamOrDevice s /* = {} */) {
if (a.shape() == shape) { if (a.shape() == shape) {
return a; return a;
} }
@ -445,11 +431,11 @@ array flatten(
if (start_ax == end_ax) { if (start_ax == end_ax) {
return a; return a;
} }
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax); Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
new_shape.push_back(-1); new_shape.push_back(-1);
new_shape.insert( new_shape.insert(
new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end()); new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
return reshape(a, new_shape, s); return reshape(a, std::move(new_shape), s);
} }
array flatten(const array& a, StreamOrDevice s /* = {} */) { array flatten(const array& a, StreamOrDevice s /* = {} */) {
@ -496,7 +482,7 @@ array squeeze(
throw std::invalid_argument("[squeeze] Received duplicate axes."); throw std::invalid_argument("[squeeze] Received duplicate axes.");
} }
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end()); std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
std::vector<int> shape; Shape shape;
for (int i = 0, j = 0; i < a.ndim(); ++i) { for (int i = 0, j = 0; i < a.ndim(); ++i) {
if (j < sorted_axes.size() && i == sorted_axes[j]) { if (j < sorted_axes.size() && i == sorted_axes[j]) {
j++; j++;
@ -584,12 +570,9 @@ array expand_dims(
// Slice helper // Slice helper
namespace { namespace {
inline auto normalize_slice( inline auto
const std::vector<int>& shape, normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
std::vector<int>& start, Shape out_shape(shape.size());
std::vector<int>& stop,
std::vector<int>& strides) {
std::vector<int> out_shape(shape.size());
bool has_neg_strides = false; bool has_neg_strides = false;
for (int i = 0; i < shape.size(); ++i) { for (int i = 0; i < shape.size(); ++i) {
@ -641,9 +624,9 @@ inline auto normalize_slice(
array slice( array slice(
const array& a, const array& a,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
std::vector<int> strides, Shape strides,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (start.size() != a.ndim() || stop.size() != a.ndim() || if (start.size() != a.ndim() || stop.size() != a.ndim() ||
strides.size() != a.ndim()) { strides.size() != a.ndim()) {
@ -670,24 +653,20 @@ array slice(
array slice( array slice(
const array& a, const array& a,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return slice( return slice(
a, a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
std::move(start),
std::move(stop),
std::vector<int>(a.ndim(), 1),
to_stream(s));
} }
/** Update a slice from the source array */ /** Update a slice from the source array */
array slice_update( array slice_update(
const array& src, const array& src,
const array& update, const array& update,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
std::vector<int> strides, Shape strides,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// Check dimensions // Check dimensions
if (start.size() != src.ndim() || stop.size() != src.ndim() || if (start.size() != src.ndim() || stop.size() != src.ndim() ||
@ -721,12 +700,11 @@ array slice_update(
array slice_update( array slice_update(
const array& src, const array& src,
const array& update, const array& update,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto strides = std::vector<int>(src.ndim(), 1);
return slice_update( return slice_update(
src, update, std::move(start), std::move(stop), std::move(strides), s); src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
} }
std::vector<array> split( std::vector<array> split(
@ -750,7 +728,7 @@ std::vector<array> split(
std::is_sorted(indices.begin(), indices.end(), std::less<>{}) && std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
indices[0] > 0 && indices.back() < a.shape(ax)) { indices[0] > 0 && indices.back() < a.shape(ax)) {
std::vector<Dtype> dtypes(indices.size() + 1, a.dtype()); std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape()); std::vector<Shape> shapes(indices.size() + 1, a.shape());
shapes[0][ax] = indices[0]; shapes[0][ax] = indices[0];
for (int i = 1; i < indices.size(); i++) { for (int i = 1; i < indices.size(); i++) {
shapes[i][ax] = indices[i] - indices[i - 1]; shapes[i][ax] = indices[i] - indices[i - 1];
@ -765,8 +743,7 @@ std::vector<array> split(
} }
std::vector<array> res; std::vector<array> res;
auto out_shape = a.shape(); auto start_indices = Shape(a.ndim(), 0);
auto start_indices = std::vector<int>(a.ndim(), 0);
auto stop_indices = a.shape(); auto stop_indices = a.shape();
for (int i = 0; i < indices.size() + 1; ++i) { for (int i = 0; i < indices.size() + 1; ++i) {
stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax); stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax);
@ -826,13 +803,13 @@ std::vector<array> meshgrid(
auto ndim = arrays.size(); auto ndim = arrays.size();
std::vector<array> outputs; std::vector<array> outputs;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
std::vector<int> shape(ndim, 1); Shape shape(ndim, 1);
shape[i] = -1; shape[i] = -1;
outputs.push_back(reshape(arrays[i], std::move(shape), s)); outputs.push_back(reshape(arrays[i], std::move(shape), s));
} }
if (indexing == "xy" and ndim > 1) { if (indexing == "xy" and ndim > 1) {
std::vector<int> shape(ndim, 1); Shape shape(ndim, 1);
shape[1] = arrays[0].size(); shape[1] = arrays[0].size();
outputs[0] = reshape(arrays[0], shape, s); outputs[0] = reshape(arrays[0], shape, s);
@ -895,7 +872,7 @@ array concatenate(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
}; };
std::vector<int> shape = arrays[0].shape(); auto shape = arrays[0].shape();
shape[ax] = 0; shape[ax] = 0;
// Make the output shape and validate that all arrays have the same shape // Make the output shape and validate that all arrays have the same shape
// except for the concatenation axis. // except for the concatenation axis.
@ -980,7 +957,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
} }
// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...) // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
std::vector<int> shape(arr.shape()); auto shape = arr.shape();
shape.insert(shape.begin() + axis + 1, repeats); shape.insert(shape.begin() + axis + 1, repeats);
array out = expand_dims(arr, axis + 1, s); array out = expand_dims(arr, axis + 1, s);
out = broadcast_to(out, shape, s); out = broadcast_to(out, shape, s);
@ -1009,9 +986,9 @@ array tile(
shape.insert(shape.begin(), reps.size() - shape.size(), 1); shape.insert(shape.begin(), reps.size() - shape.size(), 1);
} }
std::vector<int> expand_shape; Shape expand_shape;
std::vector<int> broad_shape; Shape broad_shape;
std::vector<int> final_shape; Shape final_shape;
for (int i = 0; i < shape.size(); i++) { for (int i = 0; i < shape.size(); i++) {
if (reps[i] != 1) { if (reps[i] != 1) {
expand_shape.push_back(1); expand_shape.push_back(1);
@ -1022,17 +999,17 @@ array tile(
final_shape.push_back(reps[i] * shape[i]); final_shape.push_back(reps[i] * shape[i]);
} }
auto x = reshape(arr, expand_shape, s); auto x = reshape(arr, std::move(expand_shape), s);
x = broadcast_to(x, broad_shape, s); x = broadcast_to(x, std::move(broad_shape), s);
return reshape(x, final_shape, s); return reshape(x, std::move(final_shape), s);
} }
array edge_pad( array edge_pad(
const array& a, const array& a,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& low_pad_size, const Shape& low_pad_size,
const std::vector<int>& high_pad_size, const Shape& high_pad_size,
const std::vector<int>& out_shape, const Shape& out_shape,
StreamOrDevice s /* = {}*/) { StreamOrDevice s /* = {}*/) {
array out = zeros(out_shape, a.dtype(), s); array out = zeros(out_shape, a.dtype(), s);
auto stops = a.shape(); auto stops = a.shape();
@ -1044,7 +1021,7 @@ array edge_pad(
for (int axis = 0; axis < a.ndim(); axis++) { for (int axis = 0; axis < a.ndim(); axis++) {
if (low_pad_size[axis] > 0) { if (low_pad_size[axis] > 0) {
std::vector<int> starts(a.ndim(), 0); Shape starts(a.ndim(), 0);
starts[axis] = low_pad_size[axis]; starts[axis] = low_pad_size[axis];
auto stops = out.shape(); auto stops = out.shape();
stops[axis] = low_pad_size[axis] + 1; stops[axis] = low_pad_size[axis] + 1;
@ -1058,7 +1035,7 @@ array edge_pad(
} }
if (high_pad_size[axis] > 0) { if (high_pad_size[axis] > 0) {
std::vector<int> starts(a.ndim(), 0); Shape starts(a.ndim(), 0);
starts[axis] = -high_pad_size[axis] - 1; starts[axis] = -high_pad_size[axis] - 1;
auto stops = out.shape(); auto stops = out.shape();
stops[axis] = -high_pad_size[axis]; stops[axis] = -high_pad_size[axis];
@ -1075,9 +1052,9 @@ array edge_pad(
/** Pad an array with a constant value */ /** Pad an array with a constant value */
array pad( array pad(
const array& a, const array& a,
const std::vector<int>& axes, const Shape& axes,
const std::vector<int>& low_pad_size, const Shape& low_pad_size,
const std::vector<int>& high_pad_size, const Shape& high_pad_size,
const array& pad_value /*= array(0)*/, const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/, const std::string mode /*= "constant"*/,
StreamOrDevice s /* = {}*/) { StreamOrDevice s /* = {}*/) {
@ -1089,7 +1066,7 @@ array pad(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
std::vector<int> out_shape = a.shape(); auto out_shape = a.shape();
for (int i = 0; i < axes.size(); i++) { for (int i = 0; i < axes.size(); i++) {
if (low_pad_size[i] < 0) { if (low_pad_size[i] < 0) {
@ -1113,7 +1090,7 @@ array pad(
if (mode == "constant") { if (mode == "constant") {
return array( return array(
out_shape, std::move(out_shape),
a.dtype(), a.dtype(),
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size), std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
{a, astype(pad_value, a.dtype(), s)}); {a, astype(pad_value, a.dtype(), s)});
@ -1136,8 +1113,8 @@ array pad(
std::vector<int> axes(a.ndim(), 0); std::vector<int> axes(a.ndim(), 0);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
std::vector<int> lows; Shape lows;
std::vector<int> highs; Shape highs;
for (auto& pads : pad_width) { for (auto& pads : pad_width) {
lows.push_back(pads.first); lows.push_back(pads.first);
@ -1240,7 +1217,7 @@ array transpose(
} }
// Check in bounds and for duplicates // Check in bounds and for duplicates
std::vector<int> shape(axes.size(), 0); Shape shape(axes.size(), 0);
for (auto& ax : axes) { for (auto& ax : axes) {
if (ax < 0 || ax >= a.ndim()) { if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg; std::ostringstream msg;
@ -1272,7 +1249,7 @@ array transpose(const array& a, StreamOrDevice s /* = {} */) {
array broadcast_to( array broadcast_to(
const array& a, const array& a,
const std::vector<int>& shape, const Shape& shape,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (a.shape() == shape) { if (a.shape() == shape) {
return a; return a;
@ -1295,14 +1272,14 @@ array broadcast_to(
std::vector<array> std::vector<array>
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) { broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
std::vector<int> shape = broadcast_shapes(a.shape(), b.shape()); auto shape = broadcast_shapes(a.shape(), b.shape());
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)}; return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
} }
std::vector<array> broadcast_arrays( std::vector<array> broadcast_arrays(
const std::vector<array>& inputs, const std::vector<array>& inputs,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
std::vector<int> shape{}; Shape shape{};
for (const auto& in : inputs) { for (const auto& in : inputs) {
shape = broadcast_shapes(shape, in.shape()); shape = broadcast_shapes(shape, in.shape());
} }
@ -1913,7 +1890,7 @@ array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size(); int size = a.size();
auto result = argmax(reshape(a, {size}, s), 0, true, s); auto result = argmax(reshape(a, {size}, s), 0, true, s);
if (keepdims) { if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s); result = reshape(result, Shape(a.shape().size(), 1), s);
} else { } else {
result = squeeze(result, s); result = squeeze(result, s);
} }
@ -2098,8 +2075,8 @@ array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
} }
array a_partitioned = partition(a, -k, axis_, s); array a_partitioned = partition(a, -k, axis_, s);
std::vector<int> slice_starts(a.ndim(), 0); Shape slice_starts(a.ndim(), 0);
std::vector<int> slice_ends = a.shape(); auto slice_ends = a.shape();
slice_starts[axis_] = a.shape(axis_) - k; slice_starts[axis_] = a.shape(axis_) - k;
return slice(a_partitioned, slice_starts, slice_ends, s); return slice(a_partitioned, slice_starts, slice_ends, s);
} }
@ -2613,8 +2590,8 @@ array matmul(
} }
if (a.ndim() > 2 || b.ndim() > 2) { if (a.ndim() > 2 || b.ndim() > 2) {
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b); auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
// Broadcast a // Broadcast a
@ -2648,7 +2625,7 @@ array gather(
const array& a, const array& a,
const std::vector<array>& indices, const std::vector<array>& indices,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& slice_sizes, const Shape& slice_sizes,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// Checks that indices, dimensions, and slice_sizes are all valid // Checks that indices, dimensions, and slice_sizes are all valid
if (indices.size() > a.ndim()) { if (indices.size() > a.ndim()) {
@ -2703,7 +2680,7 @@ array gather(
idx = astype(idx, dtype, s); idx = astype(idx, dtype, s);
} }
std::vector<int> out_shape; Shape out_shape;
if (!inputs.empty()) { if (!inputs.empty()) {
out_shape = inputs[0].shape(); out_shape = inputs[0].shape();
} }
@ -2741,7 +2718,7 @@ array take(
axis = axis < 0 ? a.ndim() + axis : axis; axis = axis < 0 ? a.ndim() + axis : axis;
// Make slice sizes to pass to gather // Make slice sizes to pass to gather
std::vector<int> slice_sizes = a.shape(); Shape slice_sizes = a.shape();
slice_sizes[axis] = indices.size() > 0 ? 1 : 0; slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
auto out = gather(a, indices, axis, slice_sizes, s); auto out = gather(a, indices, axis, slice_sizes, s);
@ -2759,7 +2736,7 @@ array take(
} }
// Squeeze the axis we take over // Squeeze the axis we take over
std::vector<int> out_shape = out.shape(); auto out_shape = out.shape();
out_shape.erase(out_shape.begin() + indices.ndim() + axis); out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, std::move(out_shape), s); return reshape(out, std::move(out_shape), s);
} }
@ -2787,8 +2764,8 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
// Handle negative axis // Handle negative axis
axis = axis < 0 ? a.ndim() + axis : axis; axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<int> starts(a.ndim(), 0); Shape starts(a.ndim(), 0);
std::vector<int> stops = a.shape(); Shape stops = a.shape();
starts[axis] = index; starts[axis] = index;
stops[axis] = index + 1; stops[axis] = index + 1;
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s); return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
@ -2821,7 +2798,7 @@ array take_along_axis(
axis = axis < 0 ? a.ndim() + axis : axis; axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<array> nd_indices; std::vector<array> nd_indices;
std::vector<int> index_shape(a.ndim(), 1); Shape index_shape(a.ndim(), 1);
for (int i = 0; i < a.ndim(); ++i) { for (int i = 0; i < a.ndim(); ++i) {
if (i == axis) { if (i == axis) {
nd_indices.push_back(indices); nd_indices.push_back(indices);
@ -2834,12 +2811,11 @@ array take_along_axis(
} }
std::vector<int> dims(a.ndim()); std::vector<int> dims(a.ndim());
std::iota(dims.begin(), dims.end(), 0); std::iota(dims.begin(), dims.end(), 0);
std::vector<int> slice_sizes(a.ndim(), a.size() > 0); Shape slice_sizes(a.ndim(), a.size() > 0);
auto out = gather(a, nd_indices, dims, slice_sizes, s); auto out = gather(a, nd_indices, dims, slice_sizes, s);
// Squeeze out the slice shape // Squeeze out the slice shape
std::vector<int> out_shape( Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
out.shape().begin(), out.shape().begin() + a.ndim());
return reshape(out, std::move(out_shape), s); return reshape(out, std::move(out_shape), s);
} }
@ -2867,7 +2843,7 @@ array put_along_axis(
axis = axis < 0 ? a.ndim() + axis : axis; axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<array> nd_indices; std::vector<array> nd_indices;
std::vector<int> index_shape(a.ndim(), 1); Shape index_shape(a.ndim(), 1);
for (int i = 0; i < a.ndim(); ++i) { for (int i = 0; i < a.ndim(); ++i) {
if (i == axis) { if (i == axis) {
nd_indices.push_back(indices); nd_indices.push_back(indices);
@ -2927,7 +2903,7 @@ array scatter(
// Broadcast and cast indices if necessary // Broadcast and cast indices if necessary
auto inputs = broadcast_arrays(indices); auto inputs = broadcast_arrays(indices);
std::vector<int> idx_shape; Shape idx_shape;
if (!inputs.empty()) { if (!inputs.empty()) {
idx_shape = inputs[0].shape(); idx_shape = inputs[0].shape();
} }
@ -3198,7 +3174,7 @@ inline int dilate_size(int dim, int dil) {
return 1 + dil * (dim - 1); return 1 + dil * (dim - 1);
} }
inline std::vector<int> conv_out_shape( Shape conv_out_shape(
const std::vector<int>& in_shape, const std::vector<int>& in_shape,
const std::vector<int>& wt_shape, const std::vector<int>& wt_shape,
const std::vector<int>& strides, const std::vector<int>& strides,
@ -3208,7 +3184,7 @@ inline std::vector<int> conv_out_shape(
const std::vector<int>& input_dilation) { const std::vector<int>& input_dilation) {
int N = in_shape[0]; int N = in_shape[0];
int O = wt_shape[0]; int O = wt_shape[0];
std::vector<int> out_shape(in_shape.size()); Shape out_shape(in_shape.size());
int i = 0; int i = 0;
out_shape[i++] = N; out_shape[i++] = N;
@ -3577,8 +3553,8 @@ array conv_general(
// Handle negative padding // Handle negative padding
if (has_neg_padding) { if (has_neg_padding) {
std::vector<int> starts(in.ndim(), 0); Shape starts(in.ndim(), 0);
std::vector<int> stops = in.shape(); auto stops = in.shape();
for (int i = 0; i < spatial_dims; i++) { for (int i = 0; i < spatial_dims; i++) {
if (padding_lo[i] < 0) { if (padding_lo[i] < 0) {
@ -3596,7 +3572,7 @@ array conv_general(
} }
// Get output shapes // Get output shapes
std::vector<int> out_shape = conv_out_shape( auto out_shape = conv_out_shape(
in.shape(), in.shape(),
wt.shape(), wt.shape(),
stride, stride,
@ -3606,7 +3582,7 @@ array conv_general(
input_dilation); input_dilation);
return array( return array(
out_shape, std::move(out_shape),
in.dtype(), in.dtype(),
std::make_shared<Convolution>( std::make_shared<Convolution>(
to_stream(s), to_stream(s),
@ -3634,8 +3610,8 @@ array quantized_matmul(
// QuantizedMatmul handles w.ndim == 2 case. // QuantizedMatmul handles w.ndim == 2 case.
if (x.ndim() > 2 && w.ndim() > 2) { if (x.ndim() > 2 && w.ndim() > 2) {
std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2); Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2); Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_x, bsx_w); auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
// Broadcast x // Broadcast x
@ -3731,7 +3707,7 @@ array gather_qmm(
// and output type // and output type
auto out_type = result_type(x, scales, biases); auto out_type = result_type(x, scales, biases);
auto out = array( return array(
std::move(out_shape), std::move(out_shape),
out_type, out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose), std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
@ -3741,8 +3717,6 @@ array gather_qmm(
astype(biases, out_type, s), astype(biases, out_type, s),
lhs_indices, lhs_indices,
rhs_indices}); rhs_indices});
return out;
} }
array tensordot( array tensordot(
@ -3802,7 +3776,7 @@ array tensordot(
std::vector<int> t1; std::vector<int> t1;
std::vector<int> t2; std::vector<int> t2;
std::vector<int> rshape; Shape rshape;
int size1 = 1; int size1 = 1;
int size2 = 1; int size2 = 1;
for (int i = 0; i < a.ndim(); i++) { for (int i = 0; i < a.ndim(); i++) {
@ -3898,7 +3872,7 @@ array addmm(
// We can batch the multiplication by reshaping a // We can batch the multiplication by reshaping a
if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) { if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {
std::vector<int> out_shape = a.shape(); auto out_shape = a.shape();
a = reshape(a, {-1, out_shape.back()}, s); a = reshape(a, {-1, out_shape.back()}, s);
out_shape.back() = b.shape(-1); out_shape.back() = b.shape(-1);
@ -3917,8 +3891,8 @@ array addmm(
} }
if (a.ndim() > 2 || b.ndim() > 2) { if (a.ndim() > 2 || b.ndim() > 2) {
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b); auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
// Broadcast a // Broadcast a
@ -4042,8 +4016,8 @@ array block_masked_mm(
b = astype(b, out_type, s); b = astype(b, out_type, s);
// Handle broadcasting // Handle broadcasting
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto bsx_shape = broadcast_shapes(bsx_a, bsx_b); auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);
@ -4079,7 +4053,7 @@ array block_masked_mm(
// Broadcast and astype mask // Broadcast and astype mask
auto broadcast_mask = [](array mask, auto broadcast_mask = [](array mask,
std::vector<int>& bs_shape, Shape& bs_shape,
int y, int y,
int x, int x,
Dtype mask_dtype, Dtype mask_dtype,
@ -4397,7 +4371,7 @@ std::vector<array> depends(
Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream() Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
: to_stream({}); : to_stream({});
// Make the output info // Make the output info
std::vector<std::vector<int>> shapes; std::vector<Shape> shapes;
std::vector<Dtype> dtypes; std::vector<Dtype> dtypes;
for (const auto& in : inputs) { for (const auto& in : inputs) {
shapes.emplace_back(in.shape()); shapes.emplace_back(in.shape());
@ -4434,7 +4408,7 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
case 0: case 0:
return reshape(a, {1, 1}, s); return reshape(a, {1, 1}, s);
case 1: case 1:
return reshape(a, {1, static_cast<int>(a.size())}, s); return reshape(a, {1, a.shape(0)}, s);
default: default:
return a; return a;
} }
@ -4456,7 +4430,7 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
case 0: case 0:
return reshape(a, {1, 1, 1}, s); return reshape(a, {1, 1, 1}, s);
case 1: case 1:
return reshape(a, {1, static_cast<int>(a.size()), 1}, s); return reshape(a, {1, a.shape(0), 1}, s);
case 2: case 2:
return reshape(a, {a.shape(0), a.shape(1), 1}, s); return reshape(a, {a.shape(0), a.shape(1), 1}, s);
default: default:
@ -4493,7 +4467,7 @@ array number_of_elements(
} }
return stop_gradient(array( return stop_gradient(array(
std::vector<int>{}, Shape{},
dtype, dtype,
std::make_shared<NumberOfElements>( std::make_shared<NumberOfElements>(
to_stream(s), std::move(axes), inverted, dtype), to_stream(s), std::move(axes), inverted, dtype),
@ -4613,7 +4587,7 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
array roll( array roll(
const array& a, const array& a,
const std::vector<int>& shift, const Shape& shift,
const std::vector<int>& axes, const std::vector<int>& axes,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (axes.empty()) { if (axes.empty()) {
@ -4627,7 +4601,6 @@ array roll(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
std::vector<array> parts;
array result = a; array result = a;
for (int i = 0; i < axes.size(); i++) { for (int i = 0; i < axes.size(); i++) {
int ax = axes[i]; int ax = axes[i];
@ -4641,11 +4614,11 @@ array roll(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
int sh = shift[i]; auto sh = shift[i];
int split_index = auto split_index =
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
parts = split(result, std::vector<int>{split_index}, ax, s); auto parts = split(result, Shape{split_index}, ax, s);
std::swap(parts[0], parts[1]); std::swap(parts[0], parts[1]);
result = concatenate(parts, ax, s); result = concatenate(parts, ax, s);
} }
@ -4656,19 +4629,12 @@ array roll(
array roll(const array& a, int shift, StreamOrDevice s /* = {} */) { array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
auto shape = a.shape(); auto shape = a.shape();
return reshape( return reshape(
roll( roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
reshape(a, std::vector<int>{-1}, s),
std::vector<int>{shift},
std::vector<int>{0},
s),
std::move(shape), std::move(shape),
s); s);
} }
array roll( array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) {
const array& a,
const std::vector<int>& shift,
StreamOrDevice s /* = {} */) {
int total_shift = 0; int total_shift = 0;
for (auto& s : shift) { for (auto& s : shift) {
total_shift += s; total_shift += s;
@ -4677,7 +4643,7 @@ array roll(
} }
array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) { array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
return roll(a, std::vector<int>{shift}, std::vector<int>{axis}, s); return roll(a, Shape{shift}, std::vector<int>{axis}, s);
} }
array roll( array roll(
@ -4685,20 +4651,20 @@ array roll(
int shift, int shift,
const std::vector<int>& axes, const std::vector<int>& axes,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
std::vector<int> shifts(axes.size(), shift); Shape shifts(axes.size(), shift);
return roll(a, shifts, axes, s); return roll(a, shifts, axes, s);
} }
array roll( array roll(
const array& a, const array& a,
const std::vector<int>& shift, const Shape& shift,
int axis, int axis,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
int total_shift = 0; int total_shift = 0;
for (auto& s : shift) { for (auto& s : shift) {
total_shift += s; total_shift += s;
} }
return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s); return roll(a, Shape{total_shift}, std::vector<int>{axis}, s);
} }
array real(const array& a, StreamOrDevice s /* = {} */) { array real(const array& a, StreamOrDevice s /* = {} */) {

View File

@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */ /** Create a view of an array with the given shape and strides. */
array as_strided( array as_strided(
array a, array a,
std::vector<int> shape, Shape shape,
std::vector<size_t> strides, Strides strides,
size_t offset, size_t offset,
StreamOrDevice s = {}); StreamOrDevice s = {});
@ -58,31 +58,27 @@ array as_strided(
array copy(array a, StreamOrDevice s = {}); array copy(array a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */ /** Fill an array of the given shape with the given value(s). */
array full( array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
std::vector<int> shape, array full(Shape shape, array vals, StreamOrDevice s = {});
array vals,
Dtype dtype,
StreamOrDevice s = {});
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
template <typename T> template <typename T>
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) { array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
return full(std::move(shape), array(val, dtype), to_stream(s)); return full(std::move(shape), array(val, dtype), to_stream(s));
} }
template <typename T> template <typename T>
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) { array full(Shape shape, T val, StreamOrDevice s = {}) {
return full(std::move(shape), array(val), to_stream(s)); return full(std::move(shape), array(val), to_stream(s));
} }
/** Fill an array of the given shape with zeros. */ /** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {}); array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) { inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
return zeros(shape, float32, s); return zeros(shape, float32, s);
} }
array zeros_like(const array& a, StreamOrDevice s = {}); array zeros_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with ones. */ /** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {}); array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) { inline array ones(const Shape& shape, StreamOrDevice s = {}) {
return ones(shape, float32, s); return ones(shape, float32, s);
} }
array ones_like(const array& a, StreamOrDevice s = {}); array ones_like(const array& a, StreamOrDevice s = {});
@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {}); array triu(array x, int k = 0, StreamOrDevice s = {});
/** Reshape an array to the given shape. */ /** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {}); array reshape(const array& a, Shape shape, StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */ /** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten( array flatten(
@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {});
/** Slice an array. */ /** Slice an array. */
array slice( array slice(
const array& a, const array& a,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
std::vector<int> strides, Shape strides,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Slice an array with a stride of 1 in each dimension. */ /** Slice an array with a stride of 1 in each dimension. */
array slice( array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
const array& a,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
/** Update a slice from the source array */ /** Update a slice from the source array */
array slice_update( array slice_update(
const array& src, const array& src,
const array& update, const array& update,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
std::vector<int> strides, Shape strides,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */ /** Update a slice from the source array with stride 1 in each dimension */
array slice_update( array slice_update(
const array& src, const array& src,
const array& update, const array& update,
std::vector<int> start, Shape start,
std::vector<int> stop, Shape stop,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */ /** Split an array into sub-arrays along a given axis. */
@ -288,10 +280,7 @@ array pad(
array transpose(const array& a, StreamOrDevice s = {}); array transpose(const array& a, StreamOrDevice s = {});
/** Broadcast an array to a given shape. */ /** Broadcast an array to a given shape. */
array broadcast_to( array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
const array& a,
const std::vector<int>& shape,
StreamOrDevice s = {});
/** Broadcast a vector of arrays against one another. */ /** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays( std::vector<array> broadcast_arrays(
@ -917,13 +906,13 @@ array gather(
const array& a, const array& a,
const std::vector<array>& indices, const std::vector<array>& indices,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& slice_sizes, const Shape& slice_sizes,
StreamOrDevice s = {}); StreamOrDevice s = {});
inline array gather( inline array gather(
const array& a, const array& a,
const array& indices, const array& indices,
int axis, int axis,
const std::vector<int>& slice_sizes, const Shape& slice_sizes,
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s); return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
} }
@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** Roll elements along an axis and introduce them on the other side */ /** Roll elements along an axis and introduce them on the other side */
array roll(const array& a, int shift, StreamOrDevice s = {}); array roll(const array& a, int shift, StreamOrDevice s = {});
array roll( array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
const array& a,
const std::vector<int>& shift,
StreamOrDevice s = {});
array roll(const array& a, int shift, int axis, StreamOrDevice s = {}); array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
array roll( array roll(
const array& a, const array& a,
int shift, const Shape& shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
int axis,
StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
const std::vector<int>& axes, const std::vector<int>& axes,
StreamOrDevice s = {}); StreamOrDevice s = {});

View File

@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) {
return t; return t;
} }
std::vector<int> broadcast_shapes( Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
const std::vector<int>& s1,
const std::vector<int>& s2) {
// Use the same broadcasting rules as numpy // Use the same broadcasting rules as numpy
// https://numpy.org/doc/1.20/user/theory.broadcasting.html // https://numpy.org/doc/1.20/user/theory.broadcasting.html
// "The size of the trailing axes for both arrays in an operation must // "The size of the trailing axes for both arrays in an operation must
@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes(
int diff = std::abs(ndim1 - ndim2); int diff = std::abs(ndim1 - ndim2);
const auto& big = ndim1 > ndim2 ? s1 : s2; const auto& big = ndim1 > ndim2 ? s1 : s2;
const auto& small = ndim1 > ndim2 ? s2 : s1; const auto& small = ndim1 > ndim2 ? s2 : s1;
std::vector<int> out_shape(ndim); Shape out_shape(ndim);
for (int i = ndim - 1; i >= diff; --i) { for (int i = ndim - 1; i >= diff; --i) {
int a = big[i]; int a = big[i];
int b = small[i - diff]; int b = small[i - diff];
@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
namespace { namespace {
inline size_t elem_to_loc( inline size_t
int elem, elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
size_t loc = 0; size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) { for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]); auto q_and_r = ldiv(elem, shape[i]);
@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
template <typename T> template <typename T>
void print_array(std::ostream& os, const array& a) { void print_array(std::ostream& os, const array& a) {
std::vector<int> indices(a.ndim(), 0);
os << std::boolalpha; os << std::boolalpha;
os << "array("; os << "array(";
if (a.ndim() == 0) { if (a.ndim() == 0) {
@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) { std::ostream& operator<<(std::ostream& os, const Shape& v) {
os << "("; os << "(";
for (int i = 0; i < v.size(); ++i) { for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ","); os << v[i] << ((i == v.size() - 1) ? "" : ",");
@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) { std::ostream& operator<<(std::ostream& os, const Strides& v) {
os << "("; os << "(";
for (int i = 0; i < v.size(); ++i) { for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ","); os << v[i] << ((i == v.size() - 1) ? "" : ",");

View File

@ -62,9 +62,7 @@ inline Dtype result_type(const array& a, const array& b, const array& c) {
} }
Dtype result_type(const std::vector<array>& arrays); Dtype result_type(const std::vector<array>& arrays);
std::vector<int> broadcast_shapes( Shape broadcast_shapes(const Shape& s1, const Shape& s2);
const std::vector<int>& s1,
const std::vector<int>& s2);
bool is_same_shape(const std::vector<array>& arrays); bool is_same_shape(const std::vector<array>& arrays);
@ -96,8 +94,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a); std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v); std::ostream& operator<<(std::ostream& os, const Shape& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v); std::ostream& operator<<(std::ostream& os, const Strides& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v); std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";