Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
Author: Awni Hannun
Date: 2024-12-09 11:09:02 -08:00 (committed by GitHub)
Parent: 35b412c099
Commit: 40c62c1321
102 changed files with 1262 additions and 1705 deletions
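
For reference, the `Shape` and `Strides` names used throughout this diff are container aliases defined in the MLX headers; the following is an assumed, illustrative sketch of what they look like, not the authoritative definitions.

// Assumed container aliases behind the Shape/Strides names in the diff below;
// the authoritative definitions live in the MLX headers, not here.
#include <cstdint>
#include <vector>

using ShapeElem = int32_t;             // shape dimensions stay 32-bit
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;  // strides become 64-bit everywhere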

@@ -27,10 +27,18 @@ struct ndarray_traits<float16_t> {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
int check_shape_dim(int64_t dim) {
if (dim > std::numeric_limits<int>::max()) {
throw std::invalid_argument(
"Shape dimension falls outside supported `int` range.");
}
return static_cast<int>(dim);
}
template <typename T>
array nd_array_to_mlx_contiguous(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype) {
// Make a copy of the numpy buffer
// Get buffer ptr pass to array constructor
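
The new `check_shape_dim` helper exists because nanobind reports ndarray dimensions as `int64_t` while a `Shape` entry is a 32-bit `int`. A standalone sketch of the guard and how it behaves (illustrative, with the error message copied from the hunk above):

#include <cstdint>
#include <limits>
#include <stdexcept>

// Reject dimensions that cannot be represented in a 32-bit Shape entry
// instead of silently truncating them.
int check_shape_dim(int64_t dim) {
  if (dim > std::numeric_limits<int>::max()) {
    throw std::invalid_argument(
        "Shape dimension falls outside supported `int` range.");
  }
  return static_cast<int>(dim);
}

// check_shape_dim(128);               -> returns 128
// check_shape_dim(int64_t(1) << 40);  -> throws std::invalid_argument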
@@ -42,7 +50,7 @@ array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype) {
// Compute the shape and size
std::vector<int> shape;
Shape shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
@@ -108,13 +116,12 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
a.eval();
}
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
return nb::ndarray<NDParams...>(
a.data<T>(),
a.ndim(),
shape.data(),
/* owner= */ nb::none(),
strides.data(),
a.strides().data(),
t.value_or(nb::dtype<T>()));
}
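
With the array's strides now stored as 64-bit integers, the temporary `std::vector<int64_t>` copy above becomes unnecessary: the element type already matches what `nb::ndarray` expects for strides. A small sketch of the simplification, assuming `strides()` hands back a `const Strides&`:

#include <cstdint>
#include <vector>

using Strides = std::vector<int64_t>;

// Before this change, strides were size_t and had to be copied into an
// int64_t buffer; now the existing buffer can be passed through as-is.
const int64_t* stride_ptr(const Strides& strides) {
  return strides.data();
}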
@@ -272,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
template <typename T>
PyScalarT validate_shape(
T list,
const std::vector<int>& shape,
const Shape& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) {
@@ -340,7 +347,7 @@ PyScalarT validate_shape(
}
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
void get_shape(T list, Shape& shape) {
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
@@ -351,7 +358,7 @@ void get_shape(T list, std::vector<int>& shape) {
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
shape.push_back(arr.shape(i));
}
return;
}
@@ -363,7 +370,7 @@ array array_from_list_impl(
T pl,
const PyScalarT& inferred_type,
std::optional<Dtype> specified_type,
const std::vector<int>& shape) {
const Shape& shape) {
// Make the array
switch (inferred_type) {
case pybool: {
@@ -420,7 +427,7 @@ array array_from_list_impl(
template <typename T>
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
// Compute the shape
std::vector<int> shape;
Shape shape;
get_shape(pl, shape);
// Validate the shape and type
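
`get_shape` infers an array's shape by recursing into the first element of each nested level. The following is a generic stand-in using `std::vector` nesting instead of Python sequences, purely to illustrate the recursion the diff threads `Shape` through:

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;

// Scalars contribute no further dimensions.
template <typename T>
void get_shape(const T&, Shape&) {}

// Each nesting level contributes one dimension, taken from its length;
// only the first element is inspected for deeper levels.
template <typename T>
void get_shape(const std::vector<T>& list, Shape& shape) {
  shape.push_back(static_cast<int32_t>(list.size()));
  if (!list.empty()) {
    get_shape(list.front(), shape);
  }
}

// get_shape(std::vector<std::vector<int>>{{1, 2, 3}, {4, 5, 6}}, shape)
// leaves shape == {2, 3}.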

@@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) {
m.def(
"as_strided",
[](const array& a,
std::optional<std::vector<int>> shape,
std::optional<std::vector<size_t>> strides,
std::optional<Shape> shape,
std::optional<Strides> strides,
size_t offset,
StreamOrDevice s) {
std::vector<int> a_shape = (shape) ? *shape : a.shape();
std::vector<size_t> a_strides;
auto a_shape = (shape) ? *shape : a.shape();
Strides a_strides;
if (strides) {
a_strides = *strides;
} else {
a_strides = std::vector<size_t>(a_shape.size(), 1);
a_strides = Strides(a_shape.size(), 1);
for (int i = a_shape.size() - 1; i > 0; i--) {
a_strides[i - 1] = a_shape[i] * a_strides[i];
}