Shape and Strides 1 / N (#1645)

* shape and stride type def

* more shape
This commit is contained in:
Awni Hannun 2024-12-05 12:53:43 -08:00 committed by GitHub
parent c5b0928c1f
commit fc88fd9097
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 178 additions and 242 deletions

View File

@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
}
array::array(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
@ -42,7 +42,7 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter);
}
@ -126,7 +122,7 @@ bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Strides strides,
Flags flags,
deleter_t d) {
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
@ -151,7 +147,7 @@ void array::set_data(
void array::copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@ -237,13 +233,13 @@ void array::ArrayDesc::init() {
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}
array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)

View File

@ -15,7 +15,10 @@ namespace mlx::core {
// Forward declaration
class Primitive;
using deleter_t = std::function<void(allocator::Buffer)>;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<size_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
@ -33,7 +36,7 @@ class array {
template <typename It>
array(
It data,
std::vector<int> shape,
Shape shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@ -49,15 +52,15 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
Shape shape,
Dtype dtype,
deleter_t deleter = allocator::free);
Deleter deleter = allocator::free);
/** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete;
@ -96,7 +99,7 @@ class array {
}
/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
const Shape& shape() const {
return array_desc_->shape;
}
@ -105,12 +108,12 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
int shape(int dim) const {
auto shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}
/** The strides of the array. */
const std::vector<size_t>& strides() const {
const Strides& strides() const {
return array_desc_->strides;
}
@ -119,7 +122,7 @@ class array {
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
auto strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}
@ -184,13 +187,13 @@ class array {
*/
array(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
@ -207,8 +210,8 @@ class array {
struct Data {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
Deleter d;
Data(allocator::Buffer buffer, Deleter d = allocator::free)
: buffer(buffer), d(d) {}
// Not copyable
Data(const Data& d) = delete;
@ -397,18 +400,18 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Strides strides,
Flags flags,
deleter_t d = allocator::free);
Deleter d = allocator::free);
void copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@ -417,7 +420,7 @@ class array {
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
@ -436,8 +439,8 @@ class array {
void init(const It src);
struct ArrayDesc {
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
@ -471,10 +474,10 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(Shape shape, Dtype dtype);
explicit ArrayDesc(
std::vector<int> shape,
Shape shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
std::vector<int> shape,
Shape shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data);
@ -521,7 +524,7 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {

View File

@ -16,10 +16,9 @@ namespace mlx::core {
namespace {
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
compute_reduce_shape(
std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
const std::vector<int>& axes,
const std::vector<int>& shape) {
const Shape& shape) {
bool is_noop = true;
std::set<int> axes_set;
auto ndim = shape.size();
@ -36,8 +35,8 @@ compute_reduce_shape(
if (axes_set.size() != axes.size()) {
throw std::invalid_argument("Duplicate axes detected in reduction.");
}
std::vector<int> out_shape;
std::vector<int> squeezed_shape;
Shape out_shape;
Shape squeezed_shape;
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
@ -63,7 +62,7 @@ array indices_or_default(
return indices.value();
}
std::vector<int> shape(x.shape().begin(), x.shape().end() - 2);
Shape shape(x.shape().begin(), x.shape().end() - 2);
int total =
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
return reshape(arange(total, uint32, s), shape, s);
@ -254,8 +253,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
array as_strided(
array a,
std::vector<int> shape,
std::vector<size_t> strides,
Shape shape,
Strides strides,
size_t offset,
StreamOrDevice s /* = {} */) {
auto copied_shape = shape; // |shape| will be moved
@ -279,12 +278,8 @@ array copy(array a, StreamOrDevice s /* = {} */) {
{std::move(a)});
}
array full(
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto copied_shape = shape; // |shape| will be moved
@ -295,15 +290,12 @@ array full(
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
}
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
auto dtype = vals.dtype(); // |vals| will be moved
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
}
array zeros(
const std::vector<int>& shape,
Dtype dtype,
StreamOrDevice s /* = {} */) {
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
return full(shape, array(0, dtype), to_stream(s));
}
@ -311,10 +303,7 @@ array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
return zeros(a.shape(), a.dtype(), to_stream(s));
}
array ones(
const std::vector<int>& shape,
Dtype dtype,
StreamOrDevice s /* = {} */) {
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
return full(shape, array(1, dtype), to_stream(s));
}
@ -368,10 +357,7 @@ array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
return where(mask, zeros_like(x, s), x, s);
}
array reshape(
const array& a,
std::vector<int> shape,
StreamOrDevice s /* = {} */) {
array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
if (a.shape() == shape) {
return a;
}
@ -445,11 +431,11 @@ array flatten(
if (start_ax == end_ax) {
return a;
}
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax);
Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
new_shape.push_back(-1);
new_shape.insert(
new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
return reshape(a, new_shape, s);
return reshape(a, std::move(new_shape), s);
}
array flatten(const array& a, StreamOrDevice s /* = {} */) {
@ -496,7 +482,7 @@ array squeeze(
throw std::invalid_argument("[squeeze] Received duplicate axes.");
}
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
std::vector<int> shape;
Shape shape;
for (int i = 0, j = 0; i < a.ndim(); ++i) {
if (j < sorted_axes.size() && i == sorted_axes[j]) {
j++;
@ -584,12 +570,9 @@ array expand_dims(
// Slice helper
namespace {
inline auto normalize_slice(
const std::vector<int>& shape,
std::vector<int>& start,
std::vector<int>& stop,
std::vector<int>& strides) {
std::vector<int> out_shape(shape.size());
inline auto
normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
Shape out_shape(shape.size());
bool has_neg_strides = false;
for (int i = 0; i < shape.size(); ++i) {
@ -641,9 +624,9 @@ inline auto normalize_slice(
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
Shape start,
Shape stop,
Shape strides,
StreamOrDevice s /* = {} */) {
if (start.size() != a.ndim() || stop.size() != a.ndim() ||
strides.size() != a.ndim()) {
@ -670,24 +653,20 @@ array slice(
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
Shape start,
Shape stop,
StreamOrDevice s /* = {} */) {
return slice(
a,
std::move(start),
std::move(stop),
std::vector<int>(a.ndim(), 1),
to_stream(s));
a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
}
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
Shape start,
Shape stop,
Shape strides,
StreamOrDevice s /* = {} */) {
// Check dimensions
if (start.size() != src.ndim() || stop.size() != src.ndim() ||
@ -721,12 +700,11 @@ array slice_update(
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
Shape start,
Shape stop,
StreamOrDevice s /* = {} */) {
auto strides = std::vector<int>(src.ndim(), 1);
return slice_update(
src, update, std::move(start), std::move(stop), std::move(strides), s);
src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
}
std::vector<array> split(
@ -750,7 +728,7 @@ std::vector<array> split(
std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
indices[0] > 0 && indices.back() < a.shape(ax)) {
std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
std::vector<Shape> shapes(indices.size() + 1, a.shape());
shapes[0][ax] = indices[0];
for (int i = 1; i < indices.size(); i++) {
shapes[i][ax] = indices[i] - indices[i - 1];
@ -765,8 +743,7 @@ std::vector<array> split(
}
std::vector<array> res;
auto out_shape = a.shape();
auto start_indices = std::vector<int>(a.ndim(), 0);
auto start_indices = Shape(a.ndim(), 0);
auto stop_indices = a.shape();
for (int i = 0; i < indices.size() + 1; ++i) {
stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax);
@ -826,13 +803,13 @@ std::vector<array> meshgrid(
auto ndim = arrays.size();
std::vector<array> outputs;
for (int i = 0; i < ndim; ++i) {
std::vector<int> shape(ndim, 1);
Shape shape(ndim, 1);
shape[i] = -1;
outputs.push_back(reshape(arrays[i], std::move(shape), s));
}
if (indexing == "xy" and ndim > 1) {
std::vector<int> shape(ndim, 1);
Shape shape(ndim, 1);
shape[1] = arrays[0].size();
outputs[0] = reshape(arrays[0], shape, s);
@ -895,7 +872,7 @@ array concatenate(
throw std::invalid_argument(msg.str());
};
std::vector<int> shape = arrays[0].shape();
auto shape = arrays[0].shape();
shape[ax] = 0;
// Make the output shape and validate that all arrays have the same shape
// except for the concatenation axis.
@ -980,7 +957,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
}
// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
std::vector<int> shape(arr.shape());
auto shape = arr.shape();
shape.insert(shape.begin() + axis + 1, repeats);
array out = expand_dims(arr, axis + 1, s);
out = broadcast_to(out, shape, s);
@ -1009,9 +986,9 @@ array tile(
shape.insert(shape.begin(), reps.size() - shape.size(), 1);
}
std::vector<int> expand_shape;
std::vector<int> broad_shape;
std::vector<int> final_shape;
Shape expand_shape;
Shape broad_shape;
Shape final_shape;
for (int i = 0; i < shape.size(); i++) {
if (reps[i] != 1) {
expand_shape.push_back(1);
@ -1022,17 +999,17 @@ array tile(
final_shape.push_back(reps[i] * shape[i]);
}
auto x = reshape(arr, expand_shape, s);
x = broadcast_to(x, broad_shape, s);
return reshape(x, final_shape, s);
auto x = reshape(arr, std::move(expand_shape), s);
x = broadcast_to(x, std::move(broad_shape), s);
return reshape(x, std::move(final_shape), s);
}
array edge_pad(
const array& a,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size,
const std::vector<int>& out_shape,
const Shape& low_pad_size,
const Shape& high_pad_size,
const Shape& out_shape,
StreamOrDevice s /* = {}*/) {
array out = zeros(out_shape, a.dtype(), s);
auto stops = a.shape();
@ -1044,7 +1021,7 @@ array edge_pad(
for (int axis = 0; axis < a.ndim(); axis++) {
if (low_pad_size[axis] > 0) {
std::vector<int> starts(a.ndim(), 0);
Shape starts(a.ndim(), 0);
starts[axis] = low_pad_size[axis];
auto stops = out.shape();
stops[axis] = low_pad_size[axis] + 1;
@ -1058,7 +1035,7 @@ array edge_pad(
}
if (high_pad_size[axis] > 0) {
std::vector<int> starts(a.ndim(), 0);
Shape starts(a.ndim(), 0);
starts[axis] = -high_pad_size[axis] - 1;
auto stops = out.shape();
stops[axis] = -high_pad_size[axis];
@ -1075,9 +1052,9 @@ array edge_pad(
/** Pad an array with a constant value */
array pad(
const array& a,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size,
const Shape& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const array& pad_value /*= array(0)*/,
const std::string mode /*= "constant"*/,
StreamOrDevice s /* = {}*/) {
@ -1089,7 +1066,7 @@ array pad(
throw std::invalid_argument(msg.str());
}
std::vector<int> out_shape = a.shape();
auto out_shape = a.shape();
for (int i = 0; i < axes.size(); i++) {
if (low_pad_size[i] < 0) {
@ -1113,7 +1090,7 @@ array pad(
if (mode == "constant") {
return array(
out_shape,
std::move(out_shape),
a.dtype(),
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
{a, astype(pad_value, a.dtype(), s)});
@ -1136,8 +1113,8 @@ array pad(
std::vector<int> axes(a.ndim(), 0);
std::iota(axes.begin(), axes.end(), 0);
std::vector<int> lows;
std::vector<int> highs;
Shape lows;
Shape highs;
for (auto& pads : pad_width) {
lows.push_back(pads.first);
@ -1240,7 +1217,7 @@ array transpose(
}
// Check in bounds and for duplicates
std::vector<int> shape(axes.size(), 0);
Shape shape(axes.size(), 0);
for (auto& ax : axes) {
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
@ -1272,7 +1249,7 @@ array transpose(const array& a, StreamOrDevice s /* = {} */) {
array broadcast_to(
const array& a,
const std::vector<int>& shape,
const Shape& shape,
StreamOrDevice s /* = {} */) {
if (a.shape() == shape) {
return a;
@ -1295,14 +1272,14 @@ array broadcast_to(
std::vector<array>
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
auto shape = broadcast_shapes(a.shape(), b.shape());
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
}
std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s /* = {} */) {
std::vector<int> shape{};
Shape shape{};
for (const auto& in : inputs) {
shape = broadcast_shapes(shape, in.shape());
}
@ -1913,7 +1890,7 @@ array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmax(reshape(a, {size}, s), 0, true, s);
if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
result = reshape(result, Shape(a.shape().size(), 1), s);
} else {
result = squeeze(result, s);
}
@ -2098,8 +2075,8 @@ array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
}
array a_partitioned = partition(a, -k, axis_, s);
std::vector<int> slice_starts(a.ndim(), 0);
std::vector<int> slice_ends = a.shape();
Shape slice_starts(a.ndim(), 0);
auto slice_ends = a.shape();
slice_starts[axis_] = a.shape(axis_) - k;
return slice(a_partitioned, slice_starts, slice_ends, s);
}
@ -2613,8 +2590,8 @@ array matmul(
}
if (a.ndim() > 2 || b.ndim() > 2) {
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
// Broadcast a
@ -2648,7 +2625,7 @@ array gather(
const array& a,
const std::vector<array>& indices,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes,
const Shape& slice_sizes,
StreamOrDevice s /* = {} */) {
// Checks that indices, dimensions, and slice_sizes are all valid
if (indices.size() > a.ndim()) {
@ -2703,7 +2680,7 @@ array gather(
idx = astype(idx, dtype, s);
}
std::vector<int> out_shape;
Shape out_shape;
if (!inputs.empty()) {
out_shape = inputs[0].shape();
}
@ -2741,7 +2718,7 @@ array take(
axis = axis < 0 ? a.ndim() + axis : axis;
// Make slice sizes to pass to gather
std::vector<int> slice_sizes = a.shape();
Shape slice_sizes = a.shape();
slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
auto out = gather(a, indices, axis, slice_sizes, s);
@ -2759,7 +2736,7 @@ array take(
}
// Squeeze the axis we take over
std::vector<int> out_shape = out.shape();
auto out_shape = out.shape();
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, std::move(out_shape), s);
}
@ -2787,8 +2764,8 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
// Handle negative axis
axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<int> starts(a.ndim(), 0);
std::vector<int> stops = a.shape();
Shape starts(a.ndim(), 0);
Shape stops = a.shape();
starts[axis] = index;
stops[axis] = index + 1;
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
@ -2821,7 +2798,7 @@ array take_along_axis(
axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<array> nd_indices;
std::vector<int> index_shape(a.ndim(), 1);
Shape index_shape(a.ndim(), 1);
for (int i = 0; i < a.ndim(); ++i) {
if (i == axis) {
nd_indices.push_back(indices);
@ -2834,12 +2811,11 @@ array take_along_axis(
}
std::vector<int> dims(a.ndim());
std::iota(dims.begin(), dims.end(), 0);
std::vector<int> slice_sizes(a.ndim(), a.size() > 0);
Shape slice_sizes(a.ndim(), a.size() > 0);
auto out = gather(a, nd_indices, dims, slice_sizes, s);
// Squeeze out the slice shape
std::vector<int> out_shape(
out.shape().begin(), out.shape().begin() + a.ndim());
Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
return reshape(out, std::move(out_shape), s);
}
@ -2867,7 +2843,7 @@ array put_along_axis(
axis = axis < 0 ? a.ndim() + axis : axis;
std::vector<array> nd_indices;
std::vector<int> index_shape(a.ndim(), 1);
Shape index_shape(a.ndim(), 1);
for (int i = 0; i < a.ndim(); ++i) {
if (i == axis) {
nd_indices.push_back(indices);
@ -2927,7 +2903,7 @@ array scatter(
// Broadcast and cast indices if necessary
auto inputs = broadcast_arrays(indices);
std::vector<int> idx_shape;
Shape idx_shape;
if (!inputs.empty()) {
idx_shape = inputs[0].shape();
}
@ -3198,7 +3174,7 @@ inline int dilate_size(int dim, int dil) {
return 1 + dil * (dim - 1);
}
inline std::vector<int> conv_out_shape(
Shape conv_out_shape(
const std::vector<int>& in_shape,
const std::vector<int>& wt_shape,
const std::vector<int>& strides,
@ -3208,7 +3184,7 @@ inline std::vector<int> conv_out_shape(
const std::vector<int>& input_dilation) {
int N = in_shape[0];
int O = wt_shape[0];
std::vector<int> out_shape(in_shape.size());
Shape out_shape(in_shape.size());
int i = 0;
out_shape[i++] = N;
@ -3577,8 +3553,8 @@ array conv_general(
// Handle negative padding
if (has_neg_padding) {
std::vector<int> starts(in.ndim(), 0);
std::vector<int> stops = in.shape();
Shape starts(in.ndim(), 0);
auto stops = in.shape();
for (int i = 0; i < spatial_dims; i++) {
if (padding_lo[i] < 0) {
@ -3596,7 +3572,7 @@ array conv_general(
}
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
auto out_shape = conv_out_shape(
in.shape(),
wt.shape(),
stride,
@ -3606,7 +3582,7 @@ array conv_general(
input_dilation);
return array(
out_shape,
std::move(out_shape),
in.dtype(),
std::make_shared<Convolution>(
to_stream(s),
@ -3634,8 +3610,8 @@ array quantized_matmul(
// QuantizedMatmul handles w.ndim == 2 case.
if (x.ndim() > 2 && w.ndim() > 2) {
std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
// Broadcast x
@ -3731,7 +3707,7 @@ array gather_qmm(
// and output type
auto out_type = result_type(x, scales, biases);
auto out = array(
return array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
@ -3741,8 +3717,6 @@ array gather_qmm(
astype(biases, out_type, s),
lhs_indices,
rhs_indices});
return out;
}
array tensordot(
@ -3802,7 +3776,7 @@ array tensordot(
std::vector<int> t1;
std::vector<int> t2;
std::vector<int> rshape;
Shape rshape;
int size1 = 1;
int size2 = 1;
for (int i = 0; i < a.ndim(); i++) {
@ -3898,7 +3872,7 @@ array addmm(
// We can batch the multiplication by reshaping a
if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {
std::vector<int> out_shape = a.shape();
auto out_shape = a.shape();
a = reshape(a, {-1, out_shape.back()}, s);
out_shape.back() = b.shape(-1);
@ -3917,8 +3891,8 @@ array addmm(
}
if (a.ndim() > 2 || b.ndim() > 2) {
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
// Broadcast a
@ -4042,8 +4016,8 @@ array block_masked_mm(
b = astype(b, out_type, s);
// Handle broadcasting
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);
@ -4079,7 +4053,7 @@ array block_masked_mm(
// Broadcast and astype mask
auto broadcast_mask = [](array mask,
std::vector<int>& bs_shape,
Shape& bs_shape,
int y,
int x,
Dtype mask_dtype,
@ -4397,7 +4371,7 @@ std::vector<array> depends(
Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
: to_stream({});
// Make the output info
std::vector<std::vector<int>> shapes;
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (const auto& in : inputs) {
shapes.emplace_back(in.shape());
@ -4434,7 +4408,7 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
case 0:
return reshape(a, {1, 1}, s);
case 1:
return reshape(a, {1, static_cast<int>(a.size())}, s);
return reshape(a, {1, a.shape(0)}, s);
default:
return a;
}
@ -4456,7 +4430,7 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
case 0:
return reshape(a, {1, 1, 1}, s);
case 1:
return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
return reshape(a, {1, a.shape(0), 1}, s);
case 2:
return reshape(a, {a.shape(0), a.shape(1), 1}, s);
default:
@ -4493,7 +4467,7 @@ array number_of_elements(
}
return stop_gradient(array(
std::vector<int>{},
Shape{},
dtype,
std::make_shared<NumberOfElements>(
to_stream(s), std::move(axes), inverted, dtype),
@ -4613,7 +4587,7 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
array roll(
const array& a,
const std::vector<int>& shift,
const Shape& shift,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
@ -4627,7 +4601,6 @@ array roll(
throw std::invalid_argument(msg.str());
}
std::vector<array> parts;
array result = a;
for (int i = 0; i < axes.size(); i++) {
int ax = axes[i];
@ -4641,11 +4614,11 @@ array roll(
throw std::invalid_argument(msg.str());
}
int sh = shift[i];
int split_index =
auto sh = shift[i];
auto split_index =
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
parts = split(result, std::vector<int>{split_index}, ax, s);
auto parts = split(result, Shape{split_index}, ax, s);
std::swap(parts[0], parts[1]);
result = concatenate(parts, ax, s);
}
@ -4656,19 +4629,12 @@ array roll(
array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
auto shape = a.shape();
return reshape(
roll(
reshape(a, std::vector<int>{-1}, s),
std::vector<int>{shift},
std::vector<int>{0},
s),
roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
std::move(shape),
s);
}
array roll(
const array& a,
const std::vector<int>& shift,
StreamOrDevice s /* = {} */) {
array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) {
int total_shift = 0;
for (auto& s : shift) {
total_shift += s;
@ -4677,7 +4643,7 @@ array roll(
}
array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
return roll(a, std::vector<int>{shift}, std::vector<int>{axis}, s);
return roll(a, Shape{shift}, std::vector<int>{axis}, s);
}
array roll(
@ -4685,20 +4651,20 @@ array roll(
int shift,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
std::vector<int> shifts(axes.size(), shift);
Shape shifts(axes.size(), shift);
return roll(a, shifts, axes, s);
}
array roll(
const array& a,
const std::vector<int>& shift,
const Shape& shift,
int axis,
StreamOrDevice s /* = {} */) {
int total_shift = 0;
for (auto& s : shift) {
total_shift += s;
}
return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
return roll(a, Shape{total_shift}, std::vector<int>{axis}, s);
}
array real(const array& a, StreamOrDevice s /* = {} */) {

View File

@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {});
/** Create a view of an array with the given shape and strides. */
array as_strided(
array a,
std::vector<int> shape,
std::vector<size_t> strides,
Shape shape,
Strides strides,
size_t offset,
StreamOrDevice s = {});
@ -58,31 +58,27 @@ array as_strided(
array copy(array a, StreamOrDevice s = {});
/** Fill an array of the given shape with the given value(s). */
array full(
std::vector<int> shape,
array vals,
Dtype dtype,
StreamOrDevice s = {});
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
array full(Shape shape, array vals, StreamOrDevice s = {});
template <typename T>
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
return full(std::move(shape), array(val, dtype), to_stream(s));
}
template <typename T>
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
array full(Shape shape, T val, StreamOrDevice s = {}) {
return full(std::move(shape), array(val), to_stream(s));
}
/** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
return zeros(shape, float32, s);
}
array zeros_like(const array& a, StreamOrDevice s = {});
/** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const Shape& shape, StreamOrDevice s = {}) {
return ones(shape, float32, s);
}
array ones_like(const array& a, StreamOrDevice s = {});
@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {});
/** Slice an array. */
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
Shape start,
Shape stop,
Shape strides,
StreamOrDevice s = {});
/** Slice an array with a stride of 1 in each dimension. */
array slice(
const array& a,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
Shape start,
Shape stop,
Shape strides,
StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
Shape start,
Shape stop,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
@ -288,10 +280,7 @@ array pad(
array transpose(const array& a, StreamOrDevice s = {});
/** Broadcast an array to a given shape. */
array broadcast_to(
const array& a,
const std::vector<int>& shape,
StreamOrDevice s = {});
array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
/** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays(
@ -917,13 +906,13 @@ array gather(
const array& a,
const std::vector<array>& indices,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes,
const Shape& slice_sizes,
StreamOrDevice s = {});
/**
 * Convenience overload of gather for a single index array and a single axis.
 *
 * Wraps `indices` and `axis` into one-element containers and forwards to the
 * overload that takes a vector of index arrays and a vector of axes.
 */
// NOTE(review): two `slice_sizes` parameter lines appear below (a
// std::vector<int> form and a Shape form). This looks like the removed and
// added lines of a diff rendered without +/- markers -- confirm which applies.
inline array gather(
const array& a,
const array& indices,
int axis,
const std::vector<int>& slice_sizes,
const Shape& slice_sizes,
StreamOrDevice s = {}) {
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
}
@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** Roll elements along an axis and introduce them on the other side */
array roll(const array& a, int shift, StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
array roll(
const array& a,
int shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
int axis,
StreamOrDevice s = {});
array roll(
const array& a,
const std::vector<int>& shift,
const Shape& shift,
const std::vector<int>& axes,
StreamOrDevice s = {});

View File

@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) {
return t;
}
std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2) {
Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
// Use the same broadcasting rules as numpy
// https://numpy.org/doc/1.20/user/theory.broadcasting.html
// "The size of the trailing axes for both arrays in an operation must
@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes(
int diff = std::abs(ndim1 - ndim2);
const auto& big = ndim1 > ndim2 ? s1 : s2;
const auto& small = ndim1 > ndim2 ? s2 : s1;
std::vector<int> out_shape(ndim);
Shape out_shape(ndim);
for (int i = ndim - 1; i >= diff; --i) {
int a = big[i];
int b = small[i - diff];
@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
namespace {
inline size_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
inline size_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
template <typename T>
void print_array(std::ostream& os, const array& a) {
std::vector<int> indices(a.ndim(), 0);
os << std::boolalpha;
os << "array(";
if (a.ndim() == 0) {
@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
std::ostream& operator<<(std::ostream& os, const Shape& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
std::ostream& operator<<(std::ostream& os, const Strides& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");

View File

@ -62,9 +62,7 @@ inline Dtype result_type(const array& a, const array& b, const array& c) {
}
Dtype result_type(const std::vector<array>& arrays);
std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2);
/**
 * Broadcast the shapes `s1` and `s2` against each other and return the
 * resulting shape (the definition states it follows numpy's broadcasting
 * rules).
 */
Shape broadcast_shapes(const Shape& s1, const Shape& s2);
bool is_same_shape(const std::vector<array>& arrays);
@ -96,8 +94,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
std::ostream& operator<<(std::ostream& os, const Shape& v);
std::ostream& operator<<(std::ostream& os, const Strides& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";