mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -27,10 +27,18 @@ struct ndarray_traits<float16_t> {
|
||||
static constexpr dlpack::dtype bfloat16{4, 16, 1};
|
||||
}; // namespace nanobind
|
||||
|
||||
int check_shape_dim(int64_t dim) {
|
||||
if (dim > std::numeric_limits<int>::max()) {
|
||||
throw std::invalid_argument(
|
||||
"Shape dimension falls outside supported `int` range.");
|
||||
}
|
||||
return static_cast<int>(dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
array nd_array_to_mlx_contiguous(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
Dtype dtype) {
|
||||
// Make a copy of the numpy buffer
|
||||
// Get buffer ptr pass to array constructor
|
||||
@@ -42,7 +50,7 @@ array nd_array_to_mlx(
|
||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||
std::optional<Dtype> dtype) {
|
||||
// Compute the shape and size
|
||||
std::vector<int> shape;
|
||||
Shape shape;
|
||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||
}
|
||||
@@ -108,13 +116,12 @@ nb::ndarray<NDParams...> mlx_to_nd_array_impl(
|
||||
a.eval();
|
||||
}
|
||||
std::vector<size_t> shape(a.shape().begin(), a.shape().end());
|
||||
std::vector<int64_t> strides(a.strides().begin(), a.strides().end());
|
||||
return nb::ndarray<NDParams...>(
|
||||
a.data<T>(),
|
||||
a.ndim(),
|
||||
shape.data(),
|
||||
/* owner= */ nb::none(),
|
||||
strides.data(),
|
||||
a.strides().data(),
|
||||
t.value_or(nb::dtype<T>()));
|
||||
}
|
||||
|
||||
@@ -272,7 +279,7 @@ void fill_vector(T list, std::vector<U>& vals) {
|
||||
template <typename T>
|
||||
PyScalarT validate_shape(
|
||||
T list,
|
||||
const std::vector<int>& shape,
|
||||
const Shape& shape,
|
||||
int idx,
|
||||
bool& all_python_primitive_elements) {
|
||||
if (idx >= shape.size()) {
|
||||
@@ -340,7 +347,7 @@ PyScalarT validate_shape(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void get_shape(T list, std::vector<int>& shape) {
|
||||
void get_shape(T list, Shape& shape) {
|
||||
shape.push_back(check_shape_dim(nb::len(list)));
|
||||
if (shape.back() > 0) {
|
||||
auto l = list.begin();
|
||||
@@ -351,7 +358,7 @@ void get_shape(T list, std::vector<int>& shape) {
|
||||
} else if (nb::isinstance<array>(*l)) {
|
||||
auto arr = nb::cast<array>(*l);
|
||||
for (int i = 0; i < arr.ndim(); i++) {
|
||||
shape.push_back(check_shape_dim(arr.shape(i)));
|
||||
shape.push_back(arr.shape(i));
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -363,7 +370,7 @@ array array_from_list_impl(
|
||||
T pl,
|
||||
const PyScalarT& inferred_type,
|
||||
std::optional<Dtype> specified_type,
|
||||
const std::vector<int>& shape) {
|
||||
const Shape& shape) {
|
||||
// Make the array
|
||||
switch (inferred_type) {
|
||||
case pybool: {
|
||||
@@ -420,7 +427,7 @@ array array_from_list_impl(
|
||||
template <typename T>
|
||||
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
|
||||
// Compute the shape
|
||||
std::vector<int> shape;
|
||||
Shape shape;
|
||||
get_shape(pl, shape);
|
||||
|
||||
// Validate the shape and type
|
||||
|
@@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"as_strided",
|
||||
[](const array& a,
|
||||
std::optional<std::vector<int>> shape,
|
||||
std::optional<std::vector<size_t>> strides,
|
||||
std::optional<Shape> shape,
|
||||
std::optional<Strides> strides,
|
||||
size_t offset,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> a_shape = (shape) ? *shape : a.shape();
|
||||
std::vector<size_t> a_strides;
|
||||
auto a_shape = (shape) ? *shape : a.shape();
|
||||
Strides a_strides;
|
||||
if (strides) {
|
||||
a_strides = *strides;
|
||||
} else {
|
||||
a_strides = std::vector<size_t>(a_shape.size(), 1);
|
||||
a_strides = Strides(a_shape.size(), 1);
|
||||
for (int i = a_shape.size() - 1; i > 0; i--) {
|
||||
a_strides[i - 1] = a_shape[i] * a_strides[i];
|
||||
}
|
||||
|
Reference in New Issue
Block a user