mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Shape and Strides 1 / N (#1645)
* shape and stride type def * more shape
This commit is contained in:
parent
c5b0928c1f
commit
fc88fd9097
@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
}
|
||||
|
||||
array::array(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs)
|
||||
@ -42,7 +42,7 @@ array::array(
|
||||
std::move(inputs))) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
std::vector<std::vector<int>> shapes,
|
||||
std::vector<Shape> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter)
|
||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, deleter);
|
||||
}
|
||||
@ -126,7 +122,7 @@ bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->data_size = size();
|
||||
@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
void array::set_data(
|
||||
allocator::Buffer buffer,
|
||||
size_t data_size,
|
||||
std::vector<size_t> strides,
|
||||
Strides strides,
|
||||
Flags flags,
|
||||
deleter_t d) {
|
||||
Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->data_size = data_size;
|
||||
@ -151,7 +147,7 @@ void array::set_data(
|
||||
|
||||
void array::copy_shared_buffer(
|
||||
const array& other,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {
|
||||
|
||||
void array::move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
@ -237,13 +233,13 @@ void array::ArrayDesc::init() {
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs)
|
||||
|
51
mlx/array.h
51
mlx/array.h
@ -15,7 +15,10 @@ namespace mlx::core {
|
||||
|
||||
// Forward declaration
|
||||
class Primitive;
|
||||
using deleter_t = std::function<void(allocator::Buffer)>;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using Shape = std::vector<int32_t>;
|
||||
using Strides = std::vector<size_t>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
@ -33,7 +36,7 @@ class array {
|
||||
template <typename It>
|
||||
array(
|
||||
It data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
@ -49,15 +52,15 @@ class array {
|
||||
template <typename T>
|
||||
array(
|
||||
std::initializer_list<T> data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
allocator::Buffer data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter = allocator::free);
|
||||
Deleter deleter = allocator::free);
|
||||
|
||||
/** Assignment to rvalue does not compile. */
|
||||
array& operator=(const array& other) && = delete;
|
||||
@ -96,7 +99,7 @@ class array {
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
const Shape& shape() const {
|
||||
return array_desc_->shape;
|
||||
}
|
||||
|
||||
@ -105,12 +108,12 @@ class array {
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
auto shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
const Strides& strides() const {
|
||||
return array_desc_->strides;
|
||||
}
|
||||
|
||||
@ -119,7 +122,7 @@ class array {
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
auto strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
|
||||
@ -184,13 +187,13 @@ class array {
|
||||
*/
|
||||
|
||||
array(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
std::vector<std::vector<int>> shapes,
|
||||
std::vector<Shape> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
@ -207,8 +210,8 @@ class array {
|
||||
|
||||
struct Data {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
Deleter d;
|
||||
Data(allocator::Buffer buffer, Deleter d = allocator::free)
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
@ -397,18 +400,18 @@ class array {
|
||||
// Check if the array is a tracer array
|
||||
bool is_tracer() const;
|
||||
|
||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
|
||||
|
||||
void set_data(
|
||||
allocator::Buffer buffer,
|
||||
size_t data_size,
|
||||
std::vector<size_t> strides,
|
||||
Strides strides,
|
||||
Flags flags,
|
||||
deleter_t d = allocator::free);
|
||||
Deleter d = allocator::free);
|
||||
|
||||
void copy_shared_buffer(
|
||||
const array& other,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
@ -417,7 +420,7 @@ class array {
|
||||
|
||||
void move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
@ -436,8 +439,8 @@ class array {
|
||||
void init(const It src);
|
||||
|
||||
struct ArrayDesc {
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
size_t size;
|
||||
Dtype dtype;
|
||||
std::shared_ptr<Primitive> primitive;
|
||||
@ -471,10 +474,10 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||
explicit ArrayDesc(Shape shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
template <typename It>
|
||||
array::array(
|
||||
It data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
init(data);
|
||||
@ -521,7 +524,7 @@ array::array(
|
||||
template <typename T>
|
||||
array::array(
|
||||
std::initializer_list<T> data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
|
244
mlx/ops.cpp
244
mlx/ops.cpp
@ -16,10 +16,9 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
|
||||
compute_reduce_shape(
|
||||
std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& shape) {
|
||||
const Shape& shape) {
|
||||
bool is_noop = true;
|
||||
std::set<int> axes_set;
|
||||
auto ndim = shape.size();
|
||||
@ -36,8 +35,8 @@ compute_reduce_shape(
|
||||
if (axes_set.size() != axes.size()) {
|
||||
throw std::invalid_argument("Duplicate axes detected in reduction.");
|
||||
}
|
||||
std::vector<int> out_shape;
|
||||
std::vector<int> squeezed_shape;
|
||||
Shape out_shape;
|
||||
Shape squeezed_shape;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
if (axes_set.count(i) == 0) {
|
||||
out_shape.push_back(shape[i]);
|
||||
@ -63,7 +62,7 @@ array indices_or_default(
|
||||
return indices.value();
|
||||
}
|
||||
|
||||
std::vector<int> shape(x.shape().begin(), x.shape().end() - 2);
|
||||
Shape shape(x.shape().begin(), x.shape().end() - 2);
|
||||
int total =
|
||||
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
||||
return reshape(arange(total, uint32, s), shape, s);
|
||||
@ -254,8 +253,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array as_strided(
|
||||
array a,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
@ -279,12 +278,8 @@ array copy(array a, StreamOrDevice s /* = {} */) {
|
||||
{std::move(a)});
|
||||
}
|
||||
|
||||
array full(
|
||||
std::vector<int> shape,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
|
||||
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||
}
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
@ -295,15 +290,12 @@ array full(
|
||||
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
|
||||
}
|
||||
|
||||
array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
|
||||
array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = vals.dtype(); // |vals| will be moved
|
||||
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
||||
}
|
||||
|
||||
array zeros(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return full(shape, array(0, dtype), to_stream(s));
|
||||
}
|
||||
|
||||
@ -311,10 +303,7 @@ array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return zeros(a.shape(), a.dtype(), to_stream(s));
|
||||
}
|
||||
|
||||
array ones(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return full(shape, array(1, dtype), to_stream(s));
|
||||
}
|
||||
|
||||
@ -368,10 +357,7 @@ array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
|
||||
return where(mask, zeros_like(x, s), x, s);
|
||||
}
|
||||
|
||||
array reshape(
|
||||
const array& a,
|
||||
std::vector<int> shape,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
|
||||
if (a.shape() == shape) {
|
||||
return a;
|
||||
}
|
||||
@ -445,11 +431,11 @@ array flatten(
|
||||
if (start_ax == end_ax) {
|
||||
return a;
|
||||
}
|
||||
std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax);
|
||||
Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
|
||||
new_shape.push_back(-1);
|
||||
new_shape.insert(
|
||||
new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
|
||||
return reshape(a, new_shape, s);
|
||||
return reshape(a, std::move(new_shape), s);
|
||||
}
|
||||
|
||||
array flatten(const array& a, StreamOrDevice s /* = {} */) {
|
||||
@ -496,7 +482,7 @@ array squeeze(
|
||||
throw std::invalid_argument("[squeeze] Received duplicate axes.");
|
||||
}
|
||||
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
|
||||
std::vector<int> shape;
|
||||
Shape shape;
|
||||
for (int i = 0, j = 0; i < a.ndim(); ++i) {
|
||||
if (j < sorted_axes.size() && i == sorted_axes[j]) {
|
||||
j++;
|
||||
@ -584,12 +570,9 @@ array expand_dims(
|
||||
// Slice helper
|
||||
namespace {
|
||||
|
||||
inline auto normalize_slice(
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int>& start,
|
||||
std::vector<int>& stop,
|
||||
std::vector<int>& strides) {
|
||||
std::vector<int> out_shape(shape.size());
|
||||
inline auto
|
||||
normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
|
||||
Shape out_shape(shape.size());
|
||||
bool has_neg_strides = false;
|
||||
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
@ -641,9 +624,9 @@ inline auto normalize_slice(
|
||||
|
||||
array slice(
|
||||
const array& a,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
std::vector<int> strides,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (start.size() != a.ndim() || stop.size() != a.ndim() ||
|
||||
strides.size() != a.ndim()) {
|
||||
@ -670,24 +653,20 @@ array slice(
|
||||
|
||||
array slice(
|
||||
const array& a,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return slice(
|
||||
a,
|
||||
std::move(start),
|
||||
std::move(stop),
|
||||
std::vector<int>(a.ndim(), 1),
|
||||
to_stream(s));
|
||||
a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
|
||||
}
|
||||
|
||||
/** Update a slice from the source array */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
std::vector<int> strides,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Check dimensions
|
||||
if (start.size() != src.ndim() || stop.size() != src.ndim() ||
|
||||
@ -721,12 +700,11 @@ array slice_update(
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto strides = std::vector<int>(src.ndim(), 1);
|
||||
return slice_update(
|
||||
src, update, std::move(start), std::move(stop), std::move(strides), s);
|
||||
src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
|
||||
}
|
||||
|
||||
std::vector<array> split(
|
||||
@ -750,7 +728,7 @@ std::vector<array> split(
|
||||
std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
|
||||
indices[0] > 0 && indices.back() < a.shape(ax)) {
|
||||
std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
|
||||
std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
|
||||
std::vector<Shape> shapes(indices.size() + 1, a.shape());
|
||||
shapes[0][ax] = indices[0];
|
||||
for (int i = 1; i < indices.size(); i++) {
|
||||
shapes[i][ax] = indices[i] - indices[i - 1];
|
||||
@ -765,8 +743,7 @@ std::vector<array> split(
|
||||
}
|
||||
|
||||
std::vector<array> res;
|
||||
auto out_shape = a.shape();
|
||||
auto start_indices = std::vector<int>(a.ndim(), 0);
|
||||
auto start_indices = Shape(a.ndim(), 0);
|
||||
auto stop_indices = a.shape();
|
||||
for (int i = 0; i < indices.size() + 1; ++i) {
|
||||
stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax);
|
||||
@ -826,13 +803,13 @@ std::vector<array> meshgrid(
|
||||
auto ndim = arrays.size();
|
||||
std::vector<array> outputs;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
std::vector<int> shape(ndim, 1);
|
||||
Shape shape(ndim, 1);
|
||||
shape[i] = -1;
|
||||
outputs.push_back(reshape(arrays[i], std::move(shape), s));
|
||||
}
|
||||
|
||||
if (indexing == "xy" and ndim > 1) {
|
||||
std::vector<int> shape(ndim, 1);
|
||||
Shape shape(ndim, 1);
|
||||
|
||||
shape[1] = arrays[0].size();
|
||||
outputs[0] = reshape(arrays[0], shape, s);
|
||||
@ -895,7 +872,7 @@ array concatenate(
|
||||
throw std::invalid_argument(msg.str());
|
||||
};
|
||||
|
||||
std::vector<int> shape = arrays[0].shape();
|
||||
auto shape = arrays[0].shape();
|
||||
shape[ax] = 0;
|
||||
// Make the output shape and validate that all arrays have the same shape
|
||||
// except for the concatenation axis.
|
||||
@ -980,7 +957,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
||||
}
|
||||
|
||||
// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
|
||||
std::vector<int> shape(arr.shape());
|
||||
auto shape = arr.shape();
|
||||
shape.insert(shape.begin() + axis + 1, repeats);
|
||||
array out = expand_dims(arr, axis + 1, s);
|
||||
out = broadcast_to(out, shape, s);
|
||||
@ -1009,9 +986,9 @@ array tile(
|
||||
shape.insert(shape.begin(), reps.size() - shape.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int> expand_shape;
|
||||
std::vector<int> broad_shape;
|
||||
std::vector<int> final_shape;
|
||||
Shape expand_shape;
|
||||
Shape broad_shape;
|
||||
Shape final_shape;
|
||||
for (int i = 0; i < shape.size(); i++) {
|
||||
if (reps[i] != 1) {
|
||||
expand_shape.push_back(1);
|
||||
@ -1022,17 +999,17 @@ array tile(
|
||||
final_shape.push_back(reps[i] * shape[i]);
|
||||
}
|
||||
|
||||
auto x = reshape(arr, expand_shape, s);
|
||||
x = broadcast_to(x, broad_shape, s);
|
||||
return reshape(x, final_shape, s);
|
||||
auto x = reshape(arr, std::move(expand_shape), s);
|
||||
x = broadcast_to(x, std::move(broad_shape), s);
|
||||
return reshape(x, std::move(final_shape), s);
|
||||
}
|
||||
|
||||
array edge_pad(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& low_pad_size,
|
||||
const std::vector<int>& high_pad_size,
|
||||
const std::vector<int>& out_shape,
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const Shape& out_shape,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
array out = zeros(out_shape, a.dtype(), s);
|
||||
auto stops = a.shape();
|
||||
@ -1044,7 +1021,7 @@ array edge_pad(
|
||||
|
||||
for (int axis = 0; axis < a.ndim(); axis++) {
|
||||
if (low_pad_size[axis] > 0) {
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
Shape starts(a.ndim(), 0);
|
||||
starts[axis] = low_pad_size[axis];
|
||||
auto stops = out.shape();
|
||||
stops[axis] = low_pad_size[axis] + 1;
|
||||
@ -1058,7 +1035,7 @@ array edge_pad(
|
||||
}
|
||||
|
||||
if (high_pad_size[axis] > 0) {
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
Shape starts(a.ndim(), 0);
|
||||
starts[axis] = -high_pad_size[axis] - 1;
|
||||
auto stops = out.shape();
|
||||
stops[axis] = -high_pad_size[axis];
|
||||
@ -1075,9 +1052,9 @@ array edge_pad(
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& low_pad_size,
|
||||
const std::vector<int>& high_pad_size,
|
||||
const Shape& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
@ -1089,7 +1066,7 @@ array pad(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
std::vector<int> out_shape = a.shape();
|
||||
auto out_shape = a.shape();
|
||||
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
if (low_pad_size[i] < 0) {
|
||||
@ -1113,7 +1090,7 @@ array pad(
|
||||
|
||||
if (mode == "constant") {
|
||||
return array(
|
||||
out_shape,
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
||||
{a, astype(pad_value, a.dtype(), s)});
|
||||
@ -1136,8 +1113,8 @@ array pad(
|
||||
std::vector<int> axes(a.ndim(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
|
||||
std::vector<int> lows;
|
||||
std::vector<int> highs;
|
||||
Shape lows;
|
||||
Shape highs;
|
||||
|
||||
for (auto& pads : pad_width) {
|
||||
lows.push_back(pads.first);
|
||||
@ -1240,7 +1217,7 @@ array transpose(
|
||||
}
|
||||
|
||||
// Check in bounds and for duplicates
|
||||
std::vector<int> shape(axes.size(), 0);
|
||||
Shape shape(axes.size(), 0);
|
||||
for (auto& ax : axes) {
|
||||
if (ax < 0 || ax >= a.ndim()) {
|
||||
std::ostringstream msg;
|
||||
@ -1272,7 +1249,7 @@ array transpose(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array broadcast_to(
|
||||
const array& a,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (a.shape() == shape) {
|
||||
return a;
|
||||
@ -1295,14 +1272,14 @@ array broadcast_to(
|
||||
|
||||
std::vector<array>
|
||||
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
|
||||
auto shape = broadcast_shapes(a.shape(), b.shape());
|
||||
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
|
||||
}
|
||||
|
||||
std::vector<array> broadcast_arrays(
|
||||
const std::vector<array>& inputs,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<int> shape{};
|
||||
Shape shape{};
|
||||
for (const auto& in : inputs) {
|
||||
shape = broadcast_shapes(shape, in.shape());
|
||||
}
|
||||
@ -1913,7 +1890,7 @@ array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
|
||||
int size = a.size();
|
||||
auto result = argmax(reshape(a, {size}, s), 0, true, s);
|
||||
if (keepdims) {
|
||||
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
|
||||
result = reshape(result, Shape(a.shape().size(), 1), s);
|
||||
} else {
|
||||
result = squeeze(result, s);
|
||||
}
|
||||
@ -2098,8 +2075,8 @@ array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
|
||||
}
|
||||
|
||||
array a_partitioned = partition(a, -k, axis_, s);
|
||||
std::vector<int> slice_starts(a.ndim(), 0);
|
||||
std::vector<int> slice_ends = a.shape();
|
||||
Shape slice_starts(a.ndim(), 0);
|
||||
auto slice_ends = a.shape();
|
||||
slice_starts[axis_] = a.shape(axis_) - k;
|
||||
return slice(a_partitioned, slice_starts, slice_ends, s);
|
||||
}
|
||||
@ -2613,8 +2590,8 @@ array matmul(
|
||||
}
|
||||
|
||||
if (a.ndim() > 2 || b.ndim() > 2) {
|
||||
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
|
||||
|
||||
// Broadcast a
|
||||
@ -2648,7 +2625,7 @@ array gather(
|
||||
const array& a,
|
||||
const std::vector<array>& indices,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes,
|
||||
const Shape& slice_sizes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Checks that indices, dimensions, and slice_sizes are all valid
|
||||
if (indices.size() > a.ndim()) {
|
||||
@ -2703,7 +2680,7 @@ array gather(
|
||||
idx = astype(idx, dtype, s);
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
Shape out_shape;
|
||||
if (!inputs.empty()) {
|
||||
out_shape = inputs[0].shape();
|
||||
}
|
||||
@ -2741,7 +2718,7 @@ array take(
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
// Make slice sizes to pass to gather
|
||||
std::vector<int> slice_sizes = a.shape();
|
||||
Shape slice_sizes = a.shape();
|
||||
slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
|
||||
|
||||
auto out = gather(a, indices, axis, slice_sizes, s);
|
||||
@ -2759,7 +2736,7 @@ array take(
|
||||
}
|
||||
|
||||
// Squeeze the axis we take over
|
||||
std::vector<int> out_shape = out.shape();
|
||||
auto out_shape = out.shape();
|
||||
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
|
||||
return reshape(out, std::move(out_shape), s);
|
||||
}
|
||||
@ -2787,8 +2764,8 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
|
||||
// Handle negative axis
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
std::vector<int> stops = a.shape();
|
||||
Shape starts(a.ndim(), 0);
|
||||
Shape stops = a.shape();
|
||||
starts[axis] = index;
|
||||
stops[axis] = index + 1;
|
||||
return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
|
||||
@ -2821,7 +2798,7 @@ array take_along_axis(
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<array> nd_indices;
|
||||
std::vector<int> index_shape(a.ndim(), 1);
|
||||
Shape index_shape(a.ndim(), 1);
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (i == axis) {
|
||||
nd_indices.push_back(indices);
|
||||
@ -2834,12 +2811,11 @@ array take_along_axis(
|
||||
}
|
||||
std::vector<int> dims(a.ndim());
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
std::vector<int> slice_sizes(a.ndim(), a.size() > 0);
|
||||
Shape slice_sizes(a.ndim(), a.size() > 0);
|
||||
auto out = gather(a, nd_indices, dims, slice_sizes, s);
|
||||
|
||||
// Squeeze out the slice shape
|
||||
std::vector<int> out_shape(
|
||||
out.shape().begin(), out.shape().begin() + a.ndim());
|
||||
Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
|
||||
return reshape(out, std::move(out_shape), s);
|
||||
}
|
||||
|
||||
@ -2867,7 +2843,7 @@ array put_along_axis(
|
||||
axis = axis < 0 ? a.ndim() + axis : axis;
|
||||
|
||||
std::vector<array> nd_indices;
|
||||
std::vector<int> index_shape(a.ndim(), 1);
|
||||
Shape index_shape(a.ndim(), 1);
|
||||
for (int i = 0; i < a.ndim(); ++i) {
|
||||
if (i == axis) {
|
||||
nd_indices.push_back(indices);
|
||||
@ -2927,7 +2903,7 @@ array scatter(
|
||||
// Broadcast and cast indices if necessary
|
||||
auto inputs = broadcast_arrays(indices);
|
||||
|
||||
std::vector<int> idx_shape;
|
||||
Shape idx_shape;
|
||||
if (!inputs.empty()) {
|
||||
idx_shape = inputs[0].shape();
|
||||
}
|
||||
@ -3198,7 +3174,7 @@ inline int dilate_size(int dim, int dil) {
|
||||
return 1 + dil * (dim - 1);
|
||||
}
|
||||
|
||||
inline std::vector<int> conv_out_shape(
|
||||
Shape conv_out_shape(
|
||||
const std::vector<int>& in_shape,
|
||||
const std::vector<int>& wt_shape,
|
||||
const std::vector<int>& strides,
|
||||
@ -3208,7 +3184,7 @@ inline std::vector<int> conv_out_shape(
|
||||
const std::vector<int>& input_dilation) {
|
||||
int N = in_shape[0];
|
||||
int O = wt_shape[0];
|
||||
std::vector<int> out_shape(in_shape.size());
|
||||
Shape out_shape(in_shape.size());
|
||||
int i = 0;
|
||||
out_shape[i++] = N;
|
||||
|
||||
@ -3577,8 +3553,8 @@ array conv_general(
|
||||
|
||||
// Handle negative padding
|
||||
if (has_neg_padding) {
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> stops = in.shape();
|
||||
Shape starts(in.ndim(), 0);
|
||||
auto stops = in.shape();
|
||||
|
||||
for (int i = 0; i < spatial_dims; i++) {
|
||||
if (padding_lo[i] < 0) {
|
||||
@ -3596,7 +3572,7 @@ array conv_general(
|
||||
}
|
||||
|
||||
// Get output shapes
|
||||
std::vector<int> out_shape = conv_out_shape(
|
||||
auto out_shape = conv_out_shape(
|
||||
in.shape(),
|
||||
wt.shape(),
|
||||
stride,
|
||||
@ -3606,7 +3582,7 @@ array conv_general(
|
||||
input_dilation);
|
||||
|
||||
return array(
|
||||
out_shape,
|
||||
std::move(out_shape),
|
||||
in.dtype(),
|
||||
std::make_shared<Convolution>(
|
||||
to_stream(s),
|
||||
@ -3634,8 +3610,8 @@ array quantized_matmul(
|
||||
|
||||
// QuantizedMatmul handles w.ndim == 2 case.
|
||||
if (x.ndim() > 2 && w.ndim() > 2) {
|
||||
std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
|
||||
std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
|
||||
Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
|
||||
Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
|
||||
auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
|
||||
|
||||
// Broadcast x
|
||||
@ -3731,7 +3707,7 @@ array gather_qmm(
|
||||
// and output type
|
||||
auto out_type = result_type(x, scales, biases);
|
||||
|
||||
auto out = array(
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
|
||||
@ -3741,8 +3717,6 @@ array gather_qmm(
|
||||
astype(biases, out_type, s),
|
||||
lhs_indices,
|
||||
rhs_indices});
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
array tensordot(
|
||||
@ -3802,7 +3776,7 @@ array tensordot(
|
||||
|
||||
std::vector<int> t1;
|
||||
std::vector<int> t2;
|
||||
std::vector<int> rshape;
|
||||
Shape rshape;
|
||||
int size1 = 1;
|
||||
int size2 = 1;
|
||||
for (int i = 0; i < a.ndim(); i++) {
|
||||
@ -3898,7 +3872,7 @@ array addmm(
|
||||
|
||||
// We can batch the multiplication by reshaping a
|
||||
if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {
|
||||
std::vector<int> out_shape = a.shape();
|
||||
auto out_shape = a.shape();
|
||||
a = reshape(a, {-1, out_shape.back()}, s);
|
||||
out_shape.back() = b.shape(-1);
|
||||
|
||||
@ -3917,8 +3891,8 @@ array addmm(
|
||||
}
|
||||
|
||||
if (a.ndim() > 2 || b.ndim() > 2) {
|
||||
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
|
||||
|
||||
// Broadcast a
|
||||
@ -4042,8 +4016,8 @@ array block_masked_mm(
|
||||
b = astype(b, out_type, s);
|
||||
|
||||
// Handle broadcasting
|
||||
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
|
||||
auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);
|
||||
|
||||
@ -4079,7 +4053,7 @@ array block_masked_mm(
|
||||
|
||||
// Broadcast and astype mask
|
||||
auto broadcast_mask = [](array mask,
|
||||
std::vector<int>& bs_shape,
|
||||
Shape& bs_shape,
|
||||
int y,
|
||||
int x,
|
||||
Dtype mask_dtype,
|
||||
@ -4397,7 +4371,7 @@ std::vector<array> depends(
|
||||
Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
|
||||
: to_stream({});
|
||||
// Make the output info
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
for (const auto& in : inputs) {
|
||||
shapes.emplace_back(in.shape());
|
||||
@ -4434,7 +4408,7 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
case 0:
|
||||
return reshape(a, {1, 1}, s);
|
||||
case 1:
|
||||
return reshape(a, {1, static_cast<int>(a.size())}, s);
|
||||
return reshape(a, {1, a.shape(0)}, s);
|
||||
default:
|
||||
return a;
|
||||
}
|
||||
@ -4456,7 +4430,7 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
case 0:
|
||||
return reshape(a, {1, 1, 1}, s);
|
||||
case 1:
|
||||
return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
|
||||
return reshape(a, {1, a.shape(0), 1}, s);
|
||||
case 2:
|
||||
return reshape(a, {a.shape(0), a.shape(1), 1}, s);
|
||||
default:
|
||||
@ -4493,7 +4467,7 @@ array number_of_elements(
|
||||
}
|
||||
|
||||
return stop_gradient(array(
|
||||
std::vector<int>{},
|
||||
Shape{},
|
||||
dtype,
|
||||
std::make_shared<NumberOfElements>(
|
||||
to_stream(s), std::move(axes), inverted, dtype),
|
||||
@ -4613,7 +4587,7 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
const Shape& shift,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
@ -4627,7 +4601,6 @@ array roll(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
std::vector<array> parts;
|
||||
array result = a;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
int ax = axes[i];
|
||||
@ -4641,11 +4614,11 @@ array roll(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int sh = shift[i];
|
||||
int split_index =
|
||||
auto sh = shift[i];
|
||||
auto split_index =
|
||||
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
|
||||
|
||||
parts = split(result, std::vector<int>{split_index}, ax, s);
|
||||
auto parts = split(result, Shape{split_index}, ax, s);
|
||||
std::swap(parts[0], parts[1]);
|
||||
result = concatenate(parts, ax, s);
|
||||
}
|
||||
@ -4656,19 +4629,12 @@ array roll(
|
||||
array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
|
||||
auto shape = a.shape();
|
||||
return reshape(
|
||||
roll(
|
||||
reshape(a, std::vector<int>{-1}, s),
|
||||
std::vector<int>{shift},
|
||||
std::vector<int>{0},
|
||||
s),
|
||||
roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
|
||||
std::move(shape),
|
||||
s);
|
||||
}
|
||||
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) {
|
||||
int total_shift = 0;
|
||||
for (auto& s : shift) {
|
||||
total_shift += s;
|
||||
@ -4677,7 +4643,7 @@ array roll(
|
||||
}
|
||||
|
||||
array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
|
||||
return roll(a, std::vector<int>{shift}, std::vector<int>{axis}, s);
|
||||
return roll(a, Shape{shift}, std::vector<int>{axis}, s);
|
||||
}
|
||||
|
||||
array roll(
|
||||
@ -4685,20 +4651,20 @@ array roll(
|
||||
int shift,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<int> shifts(axes.size(), shift);
|
||||
Shape shifts(axes.size(), shift);
|
||||
return roll(a, shifts, axes, s);
|
||||
}
|
||||
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
const Shape& shift,
|
||||
int axis,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
int total_shift = 0;
|
||||
for (auto& s : shift) {
|
||||
total_shift += s;
|
||||
}
|
||||
return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
|
||||
return roll(a, Shape{total_shift}, std::vector<int>{axis}, s);
|
||||
}
|
||||
|
||||
array real(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
76
mlx/ops.h
76
mlx/ops.h
@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {});
|
||||
/** Create a view of an array with the given shape and strides. */
|
||||
array as_strided(
|
||||
array a,
|
||||
std::vector<int> shape,
|
||||
std::vector<size_t> strides,
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@ -58,31 +58,27 @@ array as_strided(
|
||||
array copy(array a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape with the given value(s). */
|
||||
array full(
|
||||
std::vector<int> shape,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s = {});
|
||||
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
|
||||
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
|
||||
array full(Shape shape, array vals, StreamOrDevice s = {});
|
||||
template <typename T>
|
||||
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||
array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val, dtype), to_stream(s));
|
||||
}
|
||||
template <typename T>
|
||||
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
|
||||
array full(Shape shape, T val, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val), to_stream(s));
|
||||
}
|
||||
|
||||
/** Fill an array of the given shape with zeros. */
|
||||
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
|
||||
return zeros(shape, float32, s);
|
||||
}
|
||||
array zeros_like(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Fill an array of the given shape with ones. */
|
||||
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
|
||||
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array ones(const Shape& shape, StreamOrDevice s = {}) {
|
||||
return ones(shape, float32, s);
|
||||
}
|
||||
array ones_like(const array& a, StreamOrDevice s = {});
|
||||
@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {});
|
||||
array triu(array x, int k = 0, StreamOrDevice s = {});
|
||||
|
||||
/** Reshape an array to the given shape. */
|
||||
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
||||
|
||||
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
||||
array flatten(
|
||||
@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {});
|
||||
/** Slice an array. */
|
||||
array slice(
|
||||
const array& a,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
std::vector<int> strides,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Slice an array with a stride of 1 in each dimension. */
|
||||
array slice(
|
||||
const array& a,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
StreamOrDevice s = {});
|
||||
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
std::vector<int> strides,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
Shape strides,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Update a slice from the source array with stride 1 in each dimension */
|
||||
array slice_update(
|
||||
const array& src,
|
||||
const array& update,
|
||||
std::vector<int> start,
|
||||
std::vector<int> stop,
|
||||
Shape start,
|
||||
Shape stop,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Split an array into sub-arrays along a given axis. */
|
||||
@ -288,10 +280,7 @@ array pad(
|
||||
array transpose(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Broadcast an array to a given shape. */
|
||||
array broadcast_to(
|
||||
const array& a,
|
||||
const std::vector<int>& shape,
|
||||
StreamOrDevice s = {});
|
||||
array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
|
||||
|
||||
/** Broadcast a vector of arrays against one another. */
|
||||
std::vector<array> broadcast_arrays(
|
||||
@ -917,13 +906,13 @@ array gather(
|
||||
const array& a,
|
||||
const std::vector<array>& indices,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes,
|
||||
const Shape& slice_sizes,
|
||||
StreamOrDevice s = {});
|
||||
inline array gather(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
int axis,
|
||||
const std::vector<int>& slice_sizes,
|
||||
const Shape& slice_sizes,
|
||||
StreamOrDevice s = {}) {
|
||||
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
||||
}
|
||||
@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
|
||||
|
||||
/** Roll elements along an axis and introduce them on the other side */
|
||||
array roll(const array& a, int shift, StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
StreamOrDevice s = {});
|
||||
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
|
||||
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
|
||||
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
|
||||
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
int shift,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
int axis,
|
||||
StreamOrDevice s = {});
|
||||
array roll(
|
||||
const array& a,
|
||||
const std::vector<int>& shift,
|
||||
const Shape& shift,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
|
@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) {
|
||||
return t;
|
||||
}
|
||||
|
||||
std::vector<int> broadcast_shapes(
|
||||
const std::vector<int>& s1,
|
||||
const std::vector<int>& s2) {
|
||||
Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
|
||||
// Use the same broadcasting rules as numpy
|
||||
// https://numpy.org/doc/1.20/user/theory.broadcasting.html
|
||||
// "The size of the trailing axes for both arrays in an operation must
|
||||
@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes(
|
||||
int diff = std::abs(ndim1 - ndim2);
|
||||
const auto& big = ndim1 > ndim2 ? s1 : s2;
|
||||
const auto& small = ndim1 > ndim2 ? s2 : s1;
|
||||
std::vector<int> out_shape(ndim);
|
||||
Shape out_shape(ndim);
|
||||
for (int i = ndim - 1; i >= diff; --i) {
|
||||
int a = big[i];
|
||||
int b = small[i - diff];
|
||||
@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
|
||||
|
||||
namespace {
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
int elem,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
inline size_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
size_t loc = 0;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(elem, shape[i]);
|
||||
@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
|
||||
|
||||
template <typename T>
|
||||
void print_array(std::ostream& os, const array& a) {
|
||||
std::vector<int> indices(a.ndim(), 0);
|
||||
os << std::boolalpha;
|
||||
os << "array(";
|
||||
if (a.ndim() == 0) {
|
||||
@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||
std::ostream& operator<<(std::ostream& os, const Shape& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
||||
std::ostream& operator<<(std::ostream& os, const Strides& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
|
@ -62,9 +62,7 @@ inline Dtype result_type(const array& a, const array& b, const array& c) {
|
||||
}
|
||||
Dtype result_type(const std::vector<array>& arrays);
|
||||
|
||||
std::vector<int> broadcast_shapes(
|
||||
const std::vector<int>& s1,
|
||||
const std::vector<int>& s2);
|
||||
Shape broadcast_shapes(const Shape& s1, const Shape& s2);
|
||||
|
||||
bool is_same_shape(const std::vector<array>& arrays);
|
||||
|
||||
@ -96,8 +94,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
||||
std::ostream& operator<<(std::ostream& os, array a);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const Shape& v);
|
||||
std::ostream& operator<<(std::ostream& os, const Strides& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
|
||||
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
|
||||
|
Loading…
Reference in New Issue
Block a user