Shape and Strides 1 / N (#1645)

* shape and stride type def

* more shape
This commit is contained in:
Awni Hannun
2024-12-05 12:53:43 -08:00
committed by GitHub
parent c5b0928c1f
commit fc88fd9097
6 changed files with 178 additions and 242 deletions

View File

@@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) {
return t;
}
std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2) {
Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
// Use the same broadcasting rules as numpy
// https://numpy.org/doc/1.20/user/theory.broadcasting.html
// "The size of the trailing axes for both arrays in an operation must
@@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes(
int diff = std::abs(ndim1 - ndim2);
const auto& big = ndim1 > ndim2 ? s1 : s2;
const auto& small = ndim1 > ndim2 ? s2 : s1;
std::vector<int> out_shape(ndim);
Shape out_shape(ndim);
for (int i = ndim - 1; i >= diff; --i) {
int a = big[i];
int b = small[i - diff];
@@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
namespace {
inline size_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
inline size_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
@@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
template <typename T>
void print_array(std::ostream& os, const array& a) {
std::vector<int> indices(a.ndim(), 0);
os << std::boolalpha;
os << "array(";
if (a.ndim() == 0) {
@@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
std::ostream& operator<<(std::ostream& os, const Shape& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
@@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
std::ostream& operator<<(std::ostream& os, const Strides& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");