// Copyright © 2023-2024 Apple Inc. #include #include #include #include #include #include #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/utils.h" namespace mlx::core { namespace { std::pair, std::vector> compute_reduce_shape( const std::vector& axes, const std::vector& shape) { std::set axes_set; auto ndim = shape.size(); for (auto ax : axes) { int ax_ = (ax < 0) ? ax + ndim : ax; if (ax_ < 0 || ax_ >= ndim) { std::ostringstream msg; msg << "Invalid axis " << ax << " for array with " << ndim << " dimensions."; throw std::out_of_range(msg.str()); } axes_set.insert(ax_); } if (axes_set.size() != axes.size()) { throw std::invalid_argument("Duplicate axes detected in reduction."); } std::vector out_shape; for (int i = 0; i < ndim; ++i) { if (axes_set.count(i) == 0) { out_shape.push_back(shape[i]); } else { out_shape.push_back(1); } } std::vector sorted_axes(axes_set.begin(), axes_set.end()); return {out_shape, sorted_axes}; } Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } array indices_or_default( std::optional indices, const array& x, StreamOrDevice s) { if (indices.has_value()) { return indices.value(); } std::vector shape(x.shape().begin(), x.shape().end() - 2); int total = std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); return reshape(arange(total, uint32, s), shape, s); } std::pair extract_quantized_matmul_dims( std::string_view tag, const array& x, const array& w, const array& scales, const array& biases, bool transpose, int group_size, int bits) { if (w.dtype() != uint32) { std::ostringstream msg; msg << "[" << tag << "] The weight matrix should be uint32 " << "but received " << w.dtype(); throw std::invalid_argument(msg.str()); } if (scales.shape() != biases.shape()) { std::ostringstream msg; msg << "[" << tag << "] Scales and biases should have the same shape. " << "Received scales with shape " << scales.shape() << " and biases with " << biases.shape(); throw std::invalid_argument(msg.str()); } if (!std::equal( w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) { std::ostringstream msg; msg << "[" << tag << "] Weight, scales and biases should have the same batch shape. " << "Received weight with shape " << w.shape() << ", scales with " << scales.shape() << " and biases with " << biases.shape(); throw std::invalid_argument(msg.str()); } if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { std::ostringstream msg; msg << "[" << tag << "] The shapes of the weight and scales are " << "incompatible based on bits and group_size. w.shape() == " << w.shape() << " and scales.shape() == " << scales.shape() << " with group_size=" << group_size << " and bits=" << bits; throw std::invalid_argument(msg.str()); } int x_inner_dims = x.shape(-1); // Calculate the expanded w's dims int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2); int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits; if (w_inner_dims != x_inner_dims) { std::ostringstream msg; msg << "[" << tag << "] Last dimension of first input with " << "shape (..., " << x_inner_dims << ") does not match " << "the expanded quantized matrix (" << w_inner_dims << ", " << w_outer_dims << ") computed from shape " << w.shape() << " with group_size=" << group_size << ", bits=" << bits << " and transpose=" << std::boolalpha << transpose; throw std::invalid_argument(msg.str()); } return {w_inner_dims, w_outer_dims}; } } // namespace array arange( double start, double stop, double step, Dtype dtype, StreamOrDevice s /* = {} */) { if (dtype == bool_) { std::ostringstream msg; msg << bool_ << " not supported for arange."; throw std::invalid_argument(msg.str()); } if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) { throw std::invalid_argument("[arange] Cannot compute length."); } if (std::isinf(start) || std::isinf(stop)) { throw std::invalid_argument("[arange] Cannot compute length."); } // Check if start and stop specify a valid range because if not, we have to // return an empty array if (std::isinf(step) && (step > 0 && start < stop || step < 0 && start > stop)) { return array({start}, dtype); } double real_size = std::ceil((stop - start) / step); if (real_size > INT_MAX) { throw std::invalid_argument("[arange] Maximum size exceeded."); } int size = std::max(static_cast(real_size), 0); return array( {size}, dtype, std::make_shared(to_stream(s), start, stop, step), {}); } array arange( double start, double stop, double step, StreamOrDevice s /* = {} */) { return arange(start, stop, step, float32, to_stream(s)); } array arange( double start, double stop, Dtype dtype, StreamOrDevice s /* = {} */) { return arange(start, stop, 1.0, dtype, to_stream(s)); } array arange(double start, double stop, StreamOrDevice s /* = {} */) { return arange(start, stop, 1.0, float32, to_stream(s)); } array arange(double stop, Dtype dtype, StreamOrDevice s /* = {} */) { return arange(0.0, stop, 1.0, dtype, to_stream(s)); } array arange(double stop, StreamOrDevice s /* = {} */) { return arange(0.0, stop, 1.0, float32, to_stream(s)); } array arange(int start, int stop, int step, StreamOrDevice s /* = {} */) { return arange( static_cast(start), static_cast(stop), static_cast(step), int32, to_stream(s)); } array arange(int start, int stop, StreamOrDevice s /* = {} */) { return arange( static_cast(start), static_cast(stop), 1.0, int32, to_stream(s)); } array arange(int stop, StreamOrDevice s /* = {} */) { return arange(0.0, static_cast(stop), 1.0, int32, to_stream(s)); } array linspace( double start, double stop, int num /* = 50 */, Dtype dtype /* = float32 */, StreamOrDevice s /* = {} */) { if (num < 0) { std::ostringstream msg; msg << "[linspace] number of samples, " << num << ", must be non-negative."; throw std::invalid_argument(msg.str()); } if (num == 1) { return astype(array({start}), dtype, to_stream(s)); } array sequence = arange(0, num, float32, to_stream(s)); float step = (stop - start) / (num - 1); return astype( add(multiply(sequence, array(step), to_stream(s)), array(start), to_stream(s)), dtype, to_stream(s)); } array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { if (dtype == a.dtype()) { return std::move(a); } auto copied_shape = a.shape(); // |a| will be moved return array( std::move(copied_shape), dtype, std::make_shared(to_stream(s), dtype), {std::move(a)}); } array as_strided( array a, std::vector shape, std::vector strides, size_t offset, StreamOrDevice s /* = {} */) { auto copied_shape = shape; // |shape| will be moved auto dtype = a.dtype(); // |a| will be moved return array( std::move(copied_shape), dtype, std::make_shared( to_stream(s), std::move(shape), std::move(strides), offset), // Force the input array to be contiguous. {reshape(std::move(a), {-1}, s)}); } array copy(array a, StreamOrDevice s /* = {} */) { auto copied_shape = a.shape(); // |a| will be moved auto dtype = a.dtype(); return array( std::move(copied_shape), dtype, std::make_shared(to_stream(s)), {std::move(a)}); } array full( std::vector shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) { if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) { throw std::invalid_argument("[full] Negative dimensions not allowed."); } auto copied_shape = shape; // |shape| will be moved return array( std::move(copied_shape), dtype, std::make_shared(to_stream(s)), {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)}); } array full(std::vector shape, array vals, StreamOrDevice s /* = {} */) { auto dtype = vals.dtype(); // |vals| will be moved return full(std::move(shape), std::move(vals), dtype, to_stream(s)); } array zeros( const std::vector& shape, Dtype dtype, StreamOrDevice s /* = {} */) { return full(shape, array(0, dtype), to_stream(s)); } array zeros_like(const array& a, StreamOrDevice s /* = {} */) { return zeros(a.shape(), a.dtype(), to_stream(s)); } array ones( const std::vector& shape, Dtype dtype, StreamOrDevice s /* = {} */) { return full(shape, array(1, dtype), to_stream(s)); } array ones_like(const array& a, StreamOrDevice s /* = {} */) { return ones(a.shape(), a.dtype(), to_stream(s)); } array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) { if (n <= 0 || m <= 0) { throw std::invalid_argument("[eye] N and M must be positive integers."); } array result = zeros({n, m}, dtype, s); if (k >= m || -k >= n) { return result; } int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m); std::vector indices; auto s1 = std::max(0, -k); auto s2 = std::max(0, k); indices.push_back(arange(s1, diagonal_length + s1, int32, s)); indices.push_back(arange(s2, diagonal_length + s2, int32, s)); array ones_array = ones({diagonal_length, 1, 1}, dtype, s); return scatter(result, indices, ones_array, {0, 1}, s); } array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) { return eye(n, n, 0, dtype, s); } array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) { auto l = expand_dims(arange(n, s), 1, s); auto r = expand_dims(arange(-k, m - k, s), 0, s); return astype(greater_equal(l, r, s), type, s); } array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) { if (x.ndim() < 2) { throw std::invalid_argument("[tril] array must be at least 2-D"); } auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s); return where(mask, x, zeros_like(x, s), s); } array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) { if (x.ndim() < 2) { throw std::invalid_argument("[triu] array must be at least 2-D"); } auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s); return where(mask, zeros_like(x, s), x, s); } array reshape( const array& a, std::vector shape, StreamOrDevice s /* = {} */) { if (a.shape() == shape) { return a; } size_t size = 1; int infer_idx = -1; for (int i = 0; i < shape.size(); ++i) { if (shape[i] == -1) { if (infer_idx >= 0) { throw std::invalid_argument( "[reshape] Reshape can only infer one dimension."); } infer_idx = i; } else { size *= shape[i]; } } // Infer the shape if (size > 0) { auto q_and_r = std::ldiv(a.size(), size); if (infer_idx >= 0) { shape[infer_idx] = q_and_r.quot; size *= q_and_r.quot; } } else if (infer_idx >= 0) { throw std::invalid_argument( "[reshape] Cannot infer the shape of an empty array"); } // Check that the reshaping is valid if (a.size() != size) { std::ostringstream msg; msg << "[reshape] Cannot reshape array of size " << a.size() << " into shape " << shape << "."; throw std::invalid_argument(msg.str()); } auto p = std::make_shared(to_stream(s), shape); return array(std::move(shape), a.dtype(), std::move(p), {a}); } array flatten( const array& a, int start_axis, int end_axis /* = -1 */, StreamOrDevice s /* = {} */) { auto ndim = static_cast(a.ndim()); auto start_ax = start_axis + (start_axis < 0 ? ndim : 0); auto end_ax = end_axis + (end_axis < 0 ? ndim : 0); start_ax = std::max(0, start_ax); end_ax = std::min(ndim - 1, end_ax); if (a.ndim() == 0) { return reshape(a, {1}, s); } if (end_ax < start_ax) { throw std::invalid_argument( "[flatten] start_axis must be less than or equal to end_axis"); } if (start_ax >= ndim) { std::ostringstream msg; msg << "[flatten] Invalid start_axis " << start_axis << " for array with " << ndim << " dimensions."; throw std::invalid_argument(msg.str()); } if (end_ax < 0) { std::ostringstream msg; msg << "[flatten] Invalid end_axis " << end_axis << " for array with " << ndim << " dimensions."; throw std::invalid_argument(msg.str()); } if (start_ax == end_ax) { return a; } std::vector new_shape(a.shape().begin(), a.shape().begin() + start_ax); new_shape.push_back(-1); new_shape.insert( new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end()); return reshape(a, new_shape, s); } array flatten(const array& a, StreamOrDevice s /* = {} */) { return flatten(a, 0, a.ndim() - 1, s); } array hadamard_transform( const array& a, std::optional scale_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { // Default to an orthonormal Hadamard matrix scaled by 1/sqrt(N) float scale = scale_.has_value() ? *scale_ : 1.0f / std::sqrt(a.shape(-1)); auto dtype = issubdtype(a.dtype(), floating) ? a.dtype() : float32; return array( a.shape(), dtype, std::make_shared(to_stream(s), scale), {astype(a, dtype, s)}); } array squeeze( const array& a, const std::vector& axes, StreamOrDevice s /* = {} */) { std::set unique_axes; for (auto ax : axes) { ax = ax < 0 ? ax + a.ndim() : ax; if (ax < 0 || ax >= a.ndim()) { std::ostringstream msg; msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } if (a.shape(ax) != 1) { std::ostringstream msg; msg << "[squeeze] Cannot squeeze axis " << ax << " with size " << a.shape(ax) << " which is not equal to 1."; throw std::invalid_argument(msg.str()); } unique_axes.insert(ax); } if (unique_axes.size() != axes.size()) { throw std::invalid_argument("[squeeze] Received duplicate axes."); } std::vector sorted_axes(unique_axes.begin(), unique_axes.end()); std::vector shape; for (int i = 0, j = 0; i < a.ndim(); ++i) { if (j < sorted_axes.size() && i == sorted_axes[j]) { j++; } else { shape.push_back(a.shape(i)); } } return reshape(a, shape, s); } array squeeze(const array& a, StreamOrDevice s /* = {} */) { std::vector axes; for (int i = 0; i < a.ndim(); ++i) { if (a.shape(i) == 1) { axes.push_back(i); } } return squeeze(a, axes, s); } array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) { int out_dim = a.ndim() + 1; int ax = axis < 0 ? axis + out_dim : axis; if (ax < 0 || ax >= out_dim) { std::ostringstream msg; msg << "[expand_dims] Invalid axis " << axis << " for output array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } auto shape = a.shape(); shape.insert(shape.begin() + ax, 1); return reshape(a, std::move(shape), s); } array expand_dims( const array& a, const std::vector& axes, StreamOrDevice s /* = {} */) { { // Check for repeats std::set unique_axes(axes.begin(), axes.end()); if (unique_axes.size() != axes.size()) { throw std::invalid_argument("[expand_dims] Received duplicate axes."); } } int out_ndim = axes.size() + a.ndim(); std::vector canonical_axes = axes; for (auto& ax : canonical_axes) { ax = ax < 0 ? ax + out_ndim : ax; if (ax < 0 || ax >= out_ndim) { std::ostringstream msg; msg << "[expand_dims] Invalid axis " << ax << " for output array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } } // Check for repeats again std::set unique_axes(canonical_axes.begin(), canonical_axes.end()); if (unique_axes.size() != axes.size()) { throw std::invalid_argument("[expand_dims] Received duplicate axes."); } std::vector sorted_axes(unique_axes.begin(), unique_axes.end()); auto out_shape = a.shape(); for (int i = 0; i < sorted_axes.size(); ++i) { out_shape.insert(out_shape.begin() + sorted_axes[i], 1); } return reshape(a, std::move(out_shape), s); } // Slice helper namespace { inline auto normalize_slice( const std::vector& shape, std::vector& start, std::vector& stop, std::vector& strides) { std::vector out_shape(shape.size()); bool has_neg_strides = false; for (int i = 0; i < shape.size(); ++i) { // Following numpy docs // Negative i and j are interpreted as n + i and n + j where n is // the number of elements in the corresponding dimension. Negative // k makes stepping go towards smaller indices auto n = shape[i]; auto s = start[i]; s = s < 0 ? s + n : s; auto e = stop[i]; e = e < 0 ? e + n : e; // Note: -ve strides require start >= stop if (strides[i] < 0) { has_neg_strides = true; // Clamp to bounds auto st = std::min(s, n - 1); auto ed = std::max(-1, e); start[i] = st; stop[i] = ed > st ? st : ed; auto str = -strides[i]; out_shape[i] = (start[i] - stop[i] + str - 1) / str; } else { // Clamp to bounds auto st = std::max(0, std::min(s, n)); auto ed = std::max(0, std::min(e, n)); start[i] = st; stop[i] = ed < st ? st : ed; out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; } } return std::make_pair(has_neg_strides, out_shape); } } // namespace array slice( const array& a, std::vector start, std::vector stop, std::vector strides, StreamOrDevice s /* = {} */) { if (start.size() != a.ndim() || stop.size() != a.ndim() || strides.size() != a.ndim()) { std::ostringstream msg; msg << "[slice] Invalid number of indices or strides for " << "array with dimension " << a.ndim() << "."; throw std::invalid_argument(msg.str()); } auto [has_neg_strides, out_shape] = normalize_slice(a.shape(), start, stop, strides); if (!has_neg_strides && out_shape == a.shape()) { return a; } return array( out_shape, a.dtype(), std::make_shared( to_stream(s), std::move(start), std::move(stop), std::move(strides)), {a}); } array slice( const array& a, const std::vector& start, const std::vector& stop, StreamOrDevice s /* = {} */) { return slice(a, start, stop, std::vector(a.ndim(), 1), to_stream(s)); } /** Update a slice from the source array */ array slice_update( const array& src, const array& update, std::vector start, std::vector stop, std::vector strides, StreamOrDevice s /* = {} */) { // Check dimensions if (start.size() != src.ndim() || stop.size() != src.ndim() || strides.size() != src.ndim()) { std::ostringstream msg; msg << "[slice] Invalid number of indices or strides for " << "array with dimension " << src.ndim() << "."; throw std::invalid_argument(msg.str()); } // Process slice dimensions auto [has_neg_strides, upd_shape] = normalize_slice(src.shape(), start, stop, strides); // Broadcast update shape to slice shape auto update_broadcasted = broadcast_to(update, upd_shape, s); // If the entire src is the slice, just return the update if (!has_neg_strides && upd_shape == src.shape()) { return astype(update_broadcasted, src.dtype(), s); } return array( src.shape(), src.dtype(), std::make_shared( to_stream(s), std::move(start), std::move(stop), std::move(strides)), {src, update_broadcasted}); } /** Update a slice from the source array with stride 1 in each dimension */ array slice_update( const array& src, const array& update, std::vector start, std::vector stop, StreamOrDevice s /* = {} */) { auto strides = std::vector(src.ndim(), 1); return slice_update( src, update, std::move(start), std::move(stop), std::move(strides), s); } std::vector split( const array& a, const std::vector& indices, int axis, StreamOrDevice s /* = {} */) { auto ax = axis < 0 ? axis + a.ndim() : axis; if (ax < 0 || ax >= a.ndim()) { std::ostringstream msg; msg << "Invalid axis (" << axis << ") passed to split" << " for array with shape " << a.shape() << "."; throw std::invalid_argument(msg.str()); } if (indices.empty()) { return {a}; } if (indices.size() < 10 && std::is_sorted(indices.begin(), indices.end(), std::less<>{}) && indices[0] > 0 && indices.back() < a.shape(ax)) { std::vector dtypes(indices.size() + 1, a.dtype()); std::vector> shapes(indices.size() + 1, a.shape()); shapes[0][ax] = indices[0]; for (int i = 1; i < indices.size(); i++) { shapes[i][ax] = indices[i] - indices[i - 1]; } shapes.back()[ax] = a.shape(ax) - indices.back(); return array::make_arrays( std::move(shapes), dtypes, std::make_shared(to_stream(s), indices, ax), {a}); } std::vector res; auto out_shape = a.shape(); auto start_indices = std::vector(a.ndim(), 0); auto stop_indices = a.shape(); for (int i = 0; i < indices.size() + 1; ++i) { stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax); res.push_back(slice(a, start_indices, stop_indices, to_stream(s))); start_indices[ax] = stop_indices[ax]; } return res; } std::vector split( const array& a, const std::vector& indices, StreamOrDevice s /* = {} */) { return split(a, indices, 0, s); } std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { auto ax = axis < 0 ? axis + a.ndim() : axis; if (ax < 0 || ax >= a.ndim()) { std::ostringstream msg; msg << "Invalid axis " << axis << " passed to split" << " for array with shape " << a.shape() << "."; throw std::invalid_argument(msg.str()); } auto q_and_r = std::ldiv(a.shape(axis), num_splits); if (q_and_r.rem) { std::ostringstream msg; msg << "Array split does not result in sub arrays with equal size:" << " attempting " << num_splits << " splits along axis " << axis << " for shape " << a.shape() << "."; throw std::invalid_argument(msg.str()); } auto split_size = q_and_r.quot; std::vector indices(num_splits - 1); for (int i = 0; i < indices.size(); ++i) { indices[i] = (i + 1) * split_size; } return split(a, indices, axis, s); } std::vector split(const array& a, int num_splits, StreamOrDevice s /* = {} */) { return split(a, num_splits, 0, to_stream(s)); } std::vector meshgrid( const std::vector& arrays, bool sparse /* = false */, std::string indexing /* = "xy" */, StreamOrDevice s /* = {} */) { if (indexing != "xy" && indexing != "ij") { throw std::invalid_argument( "[meshgrid] Invalid indexing value. Valid values are 'xy' and 'ij'."); } auto ndim = arrays.size(); std::vector outputs; for (int i = 0; i < ndim; ++i) { std::vector shape(ndim, 1); shape[i] = -1; outputs.push_back(reshape(arrays[i], std::move(shape), s)); } if (indexing == "xy" and ndim > 1) { std::vector shape(ndim, 1); shape[1] = arrays[0].size(); outputs[0] = reshape(arrays[0], shape, s); shape[1] = 1; shape[0] = arrays[1].size(); outputs[1] = reshape(arrays[1], std::move(shape), s); } if (!sparse) { outputs = broadcast_arrays(outputs, s); } return outputs; } array clip( const array& a, const std::optional& a_min, const std::optional& a_max, StreamOrDevice s /* = {} */) { if (!a_min.has_value() && !a_max.has_value()) { throw std::invalid_argument("At most one of a_min and a_max may be None"); } array result = astype(a, a.dtype(), s); if (a_min.has_value()) { result = maximum(result, a_min.value(), s); } if (a_max.has_value()) { result = minimum(result, a_max.value(), s); } return result; } array concatenate( const std::vector& arrays, int axis, StreamOrDevice s /* = {} */) { if (arrays.size() == 0) { throw std::invalid_argument( "[concatenate] No arrays provided for concatenation"); } // Normalize the given axis auto ax = axis < 0 ? axis + arrays[0].ndim() : axis; if (ax < 0 || ax >= arrays[0].ndim()) { std::ostringstream msg; msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate" << " for array with shape " << arrays[0].shape() << "."; throw std::invalid_argument(msg.str()); } auto throw_invalid_shapes = [&]() { std::ostringstream msg; msg << "[concatenate] All the input array dimensions must match exactly " << "except for the concatenation axis. However, the provided shapes are "; for (auto& a : arrays) { msg << a.shape() << ", "; } msg << "and the concatenation axis is " << axis << "."; throw std::invalid_argument(msg.str()); }; std::vector shape = arrays[0].shape(); shape[ax] = 0; // Make the output shape and validate that all arrays have the same shape // except for the concatenation axis. for (auto& a : arrays) { if (a.ndim() != shape.size()) { std::ostringstream msg; msg << "[concatenate] All the input arrays must have the same number of " << "dimensions. However, got arrays with dimensions " << shape.size() << " and " << a.ndim() << "."; throw std::invalid_argument(msg.str()); } for (int i = 0; i < a.ndim(); i++) { if (i == ax) { continue; } if (a.shape(i) != shape[i]) { throw_invalid_shapes(); } } shape[ax] += a.shape(ax); } // Promote all the arrays to the same type auto dtype = result_type(arrays); return array( std::move(shape), dtype, std::make_shared(to_stream(s), ax), std::move(arrays)); } array concatenate( const std::vector& arrays, StreamOrDevice s /* = {} */) { std::vector flat_inputs; for (auto& a : arrays) { flat_inputs.push_back(reshape(a, {-1}, s)); } return concatenate(flat_inputs, 0, s); } /** Stack arrays along a new axis */ array stack( const std::vector& arrays, int axis, StreamOrDevice s /* = {} */) { if (arrays.empty()) { throw std::invalid_argument("No arrays provided for stacking"); } if (!is_same_shape(arrays)) { throw std::invalid_argument("All arrays must have the same shape"); } int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1); std::vector new_arrays; new_arrays.reserve(arrays.size()); for (auto& a : arrays) { new_arrays.emplace_back(expand_dims(a, normalized_axis, s)); } return concatenate(new_arrays, axis, s); } array stack(const std::vector& arrays, StreamOrDevice s /* = {} */) { return stack(arrays, 0, s); } /** array repeat with axis */ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { axis = normalize_axis(axis, arr.ndim()); if (repeats < 0) { throw std::invalid_argument( "[repeat] Number of repeats cannot be negative"); } if (repeats == 0) { return array({}, arr.dtype()); } if (repeats == 1) { return arr; } // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...) std::vector shape(arr.shape()); shape.insert(shape.begin() + axis + 1, repeats); array out = expand_dims(arr, axis + 1, s); out = broadcast_to(out, shape, s); // Reshape back into a contiguous array where S_axis is now S_axis * repeats shape.erase(shape.begin() + axis + 1); shape[axis] *= repeats; out = reshape(out, shape, s); return out; } array repeat(const array& arr, int repeats, StreamOrDevice s) { return repeat(flatten(arr, s), repeats, 0, s); } array tile( const array& arr, std::vector reps, StreamOrDevice s /* = {} */) { auto shape = arr.shape(); if (reps.size() < shape.size()) { reps.insert(reps.begin(), shape.size() - reps.size(), 1); } if (reps.size() > shape.size()) { shape.insert(shape.begin(), reps.size() - shape.size(), 1); } std::vector expand_shape; std::vector broad_shape; std::vector final_shape; for (int i = 0; i < shape.size(); i++) { if (reps[i] != 1) { expand_shape.push_back(1); broad_shape.push_back(reps[i]); } expand_shape.push_back(shape[i]); broad_shape.push_back(shape[i]); final_shape.push_back(reps[i] * shape[i]); } auto x = reshape(arr, expand_shape, s); x = broadcast_to(x, broad_shape, s); return reshape(x, final_shape, s); } /** Pad an array with a constant value */ array pad( const array& a, const std::vector& axes, const std::vector& low_pad_size, const std::vector& high_pad_size, const array& pad_value /*= array(0)*/, StreamOrDevice s /* = {}*/) { if (axes.size() != low_pad_size.size() || axes.size() != high_pad_size.size()) { std::ostringstream msg; msg << "Invalid number of padding sizes passed to pad " << "with axes of size " << axes.size(); throw std::invalid_argument(msg.str()); } std::vector out_shape = a.shape(); for (int i = 0; i < axes.size(); i++) { if (low_pad_size[i] < 0) { std::ostringstream msg; msg << "Invalid low padding size (" << low_pad_size[i] << ") passed to pad" << " for axis " << i << ". Padding sizes must be non-negative"; throw std::invalid_argument(msg.str()); } if (high_pad_size[i] < 0) { std::ostringstream msg; msg << "Invalid high padding size (" << high_pad_size[i] << ") passed to pad" << " for axis " << i << ". Padding sizes must be non-negative"; throw std::invalid_argument(msg.str()); } auto ax = axes[i] < 0 ? a.ndim() + axes[i] : axes[i]; out_shape[ax] += low_pad_size[i] + high_pad_size[i]; } return array( out_shape, a.dtype(), std::make_shared(to_stream(s), axes, low_pad_size, high_pad_size), {a, astype(pad_value, a.dtype(), s)}); } /** Pad an array with a constant value along all axes */ array pad( const array& a, const std::vector>& pad_width, const array& pad_value /*= array(0)*/, StreamOrDevice s /*= {}*/) { std::vector axes(a.ndim(), 0); std::iota(axes.begin(), axes.end(), 0); std::vector lows; std::vector highs; for (auto& pads : pad_width) { lows.push_back(pads.first); highs.push_back(pads.second); } return pad(a, axes, lows, highs, pad_value, s); } array pad( const array& a, const std::pair& pad_width, const array& pad_value /*= array(0)*/, StreamOrDevice s /*= {}*/) { return pad( a, std::vector>(a.ndim(), pad_width), pad_value, s); } array pad( const array& a, int pad_width, const array& pad_value /*= array(0)*/, StreamOrDevice s /*= {}*/) { return pad( a, std::vector>(a.ndim(), {pad_width, pad_width}), pad_value, s); } array moveaxis( const array& a, int source, int destination, StreamOrDevice s /* = {} */) { auto check_ax = [&a](int ax) { auto ndim = static_cast(a.ndim()); if (ax < -ndim || ax >= ndim) { std::ostringstream msg; msg << "[moveaxis] Invalid axis " << ax << " for array with " << ndim << " dimensions."; throw std::out_of_range(msg.str()); } return ax < 0 ? ax + ndim : ax; }; source = check_ax(source); destination = check_ax(destination); std::vector reorder(a.ndim()); std::iota(reorder.begin(), reorder.end(), 0); reorder.erase(reorder.begin() + source); reorder.insert(reorder.begin() + destination, source); return transpose(a, reorder, s); } array swapaxes( const array& a, int axis1, int axis2, StreamOrDevice s /* = {} */) { auto check_ax = [&a](int ax) { auto ndim = static_cast(a.ndim()); if (ax < -ndim || ax >= ndim) { std::ostringstream msg; msg << "[swapaxes] Invalid axis " << ax << " for array with " << ndim << " dimensions."; throw std::out_of_range(msg.str()); } return ax < 0 ? ax + ndim : ax; }; axis1 = check_ax(axis1); axis2 = check_ax(axis2); std::vector reorder(a.ndim()); std::iota(reorder.begin(), reorder.end(), 0); std::swap(reorder[axis1], reorder[axis2]); return transpose(a, std::move(reorder), s); } array transpose( const array& a, std::vector axes, StreamOrDevice s /* = {} */) { for (auto& ax : axes) { ax = ax < 0 ? ax + a.ndim() : ax; } if (axes.size() != a.ndim()) { std::ostringstream msg; msg << "[transpose] Recived " << axes.size() << " axes for array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } // Check in bounds and for duplicates std::vector shape(axes.size(), 0); for (auto& ax : axes) { if (ax < 0 || ax >= a.ndim()) { std::ostringstream msg; msg << "[transpose] Invalid axis (" << ax << ") for array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } if (shape[ax] != 0) { throw std::invalid_argument("[transpose] Repeat axes not allowed."); } shape[ax] = 1; } for (int i = 0; i < axes.size(); ++i) { shape[i] = a.shape()[axes[i]]; } return array( std::move(shape), a.dtype(), std::make_shared(to_stream(s), std::move(axes)), {a}); } array transpose(const array& a, StreamOrDevice s /* = {} */) { std::vector axes(a.ndim()); std::iota(axes.rbegin(), axes.rend(), 0); return transpose(a, std::move(axes), to_stream(s)); } array broadcast_to( const array& a, const std::vector& shape, StreamOrDevice s /* = {} */) { if (a.shape() == shape) { return a; } // Make sure the shapes are broadcastable auto bxshape = broadcast_shapes(a.shape(), shape); if (bxshape != shape) { std::ostringstream msg; msg << "Cannot broadcast array of shape " << a.shape() << " into shape " << shape << "."; throw std::invalid_argument(msg.str()); } return array( std::move(bxshape), a.dtype(), std::make_shared(to_stream(s), shape), {a}); } std::vector broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) { std::vector shape = broadcast_shapes(a.shape(), b.shape()); return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)}; } std::vector broadcast_arrays( const std::vector& inputs, StreamOrDevice s /* = {} */) { std::vector shape{}; for (const auto& in : inputs) { shape = broadcast_shapes(shape, in.shape()); } std::vector outputs; for (const auto& in : inputs) { outputs.push_back(broadcast_to(in, shape, s)); } return outputs; } array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array greater_equal( const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array array_equal( const array& a, const array& b, bool equal_nan, StreamOrDevice s /* = {} */) { if (a.shape() != b.shape()) { return array(false); } else { auto dtype = promote_types(a.dtype(), b.dtype()); equal_nan &= issubdtype(dtype, inexact); return all( array( a.shape(), bool_, std::make_shared(to_stream(s), equal_nan), {astype(a, dtype, s), astype(b, dtype, s)}), false, s); } } array isnan(const array& a, StreamOrDevice s /* = {} */) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); } return not_equal(a, a, s); } array isinf(const array& a, StreamOrDevice s /* = {} */) { return logical_or(isposinf(a, s), isneginf(a, s), s); } array isposinf(const array& a, StreamOrDevice s /* = {} */) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); } return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); } array isneginf(const array& a, StreamOrDevice s /* = {} */) { if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) { return full(a.shape(), false, bool_, s); } return equal(a, array(-std::numeric_limits::infinity(), a.dtype()), s); } array where( const array& a, const array& b, const array& c, StreamOrDevice s /* = {} */) { auto condition = astype(a, bool_, s); Dtype out_dtype = promote_types(b.dtype(), c.dtype()); auto inputs = broadcast_arrays( {condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s); return array( inputs[0].shape(), out_dtype, std::make_shared