mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -78,11 +78,11 @@ void ternary_op_dims(
|
||||
const T3* c,
|
||||
U* out,
|
||||
Op op,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& c_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& c_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
|
||||
size_t stride = out_strides[ndim - 3];
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
|
||||
Reference in New Issue
Block a user