Shape and Strides 1 / N (#1645)

* shape and stride type def

* more shape
This commit is contained in:
Awni Hannun
2024-12-05 12:53:43 -08:00
committed by GitHub
parent c5b0928c1f
commit fc88fd9097
6 changed files with 178 additions and 242 deletions

View File

@@ -62,9 +62,7 @@ inline Dtype result_type(const array& a, const array& b, const array& c) {
}
Dtype result_type(const std::vector<array>& arrays);
std::vector<int> broadcast_shapes(
const std::vector<int>& s1,
const std::vector<int>& s2);
Shape broadcast_shapes(const Shape& s1, const Shape& s2);
bool is_same_shape(const std::vector<array>& arrays);
@@ -96,8 +94,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
std::ostream& operator<<(std::ostream& os, const Shape& v);
std::ostream& operator<<(std::ostream& os, const Strides& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";