Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -27,10 +27,18 @@ struct ndarray_traits<float16_t> {
static constexpr dlpack::dtype bfloat16{4, 16, 1};
}; // namespace nanobind
int check_shape_dim(int64_t dim) {
if (dim > std::numeric_limits<int>::max()) {
throw std::invalid_argument(
"Shape dimension falls outside supported `int` range.");
}
return static_cast<int>(dim);
}
template <typename T>
array nd_array_to_mlx_contiguous(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
const std::vector<int>& shape,
const Shape& shape,
Dtype dtype) {
// Make a copy of the numpy buffer
// Get buffer ptr pass to array constructor
@@ -42,7 +50,7 @@ array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype) {
// Compute the shape and size
std::vector<int> shape;
Shape shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
@@ -108,13 +116,12 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
a.eval();
}
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
return nb::ndarray<NDParams...>(
a.data<T>(),
a.ndim(),
shape.data(),
/* owner= */ nb::none(),
strides.data(),
a.strides().data(),
t.value_or(nb::dtype<T>()));
}
@@ -272,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
template <typename T>
PyScalarT validate_shape(
T list,
const std::vector<int>& shape,
const Shape& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) {
@@ -340,7 +347,7 @@ PyScalarT validate_shape(
}
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
void get_shape(T list, Shape& shape) {
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
@@ -351,7 +358,7 @@ void get_shape(T list, std::vector<int>& shape) {
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
shape.push_back(arr.shape(i));
}
return;
}
@@ -363,7 +370,7 @@ array array_from_list_impl(
T pl,
const PyScalarT& inferred_type,
std::optional<Dtype> specified_type,
const std::vector<int>& shape) {
const Shape& shape) {
// Make the array
switch (inferred_type) {
case pybool: {
@@ -420,7 +427,7 @@ array array_from_list_impl(
template <typename T>
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
// Compute the shape
std::vector<int> shape;
Shape shape;
get_shape(pl, shape);
// Validate the shape and type