mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Shape and Strides 1 / N (#1645)
* shape and stride type def * more shape
This commit is contained in:
@@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) {
|
||||
return t;
|
||||
}
|
||||
|
||||
std::vector<int> broadcast_shapes(
|
||||
const std::vector<int>& s1,
|
||||
const std::vector<int>& s2) {
|
||||
Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
|
||||
// Use the same broadcasting rules as numpy
|
||||
// https://numpy.org/doc/1.20/user/theory.broadcasting.html
|
||||
// "The size of the trailing axes for both arrays in an operation must
|
||||
@@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes(
|
||||
int diff = std::abs(ndim1 - ndim2);
|
||||
const auto& big = ndim1 > ndim2 ? s1 : s2;
|
||||
const auto& small = ndim1 > ndim2 ? s2 : s1;
|
||||
std::vector<int> out_shape(ndim);
|
||||
Shape out_shape(ndim);
|
||||
for (int i = ndim - 1; i >= diff; --i) {
|
||||
int a = big[i];
|
||||
int b = small[i - diff];
|
||||
@@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
|
||||
|
||||
namespace {
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
int elem,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
inline size_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
size_t loc = 0;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(elem, shape[i]);
|
||||
@@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
|
||||
|
||||
template <typename T>
|
||||
void print_array(std::ostream& os, const array& a) {
|
||||
std::vector<int> indices(a.ndim(), 0);
|
||||
os << std::boolalpha;
|
||||
os << "array(";
|
||||
if (a.ndim() == 0) {
|
||||
@@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||
std::ostream& operator<<(std::ostream& os, const Shape& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
@@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
||||
std::ostream& operator<<(std::ostream& os, const Strides& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
|
||||
Reference in New Issue
Block a user