mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -617,7 +617,7 @@ TEST_CASE("test op vjps") {
|
||||
axes = {0};
|
||||
out = vjp(fun, array({}), array(3.0f)).second;
|
||||
CHECK_EQ(out.size(), 0);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{0});
|
||||
CHECK_EQ(out.shape(), Shape{0});
|
||||
|
||||
axes = {0};
|
||||
out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}))
|
||||
@@ -725,9 +725,9 @@ TEST_CASE("test gather and take grads") {
|
||||
}
|
||||
|
||||
TEST_CASE("test slice grads") {
|
||||
std::vector<int> start = {5, 0, 0};
|
||||
std::vector<int> stop = {7, 2, 4};
|
||||
std::vector<int> strides = {1, 1, 1};
|
||||
Shape start = {5, 0, 0};
|
||||
Shape stop = {7, 2, 4};
|
||||
Shape strides = {1, 1, 1};
|
||||
|
||||
auto fn = [&start, &stop, &strides](array input) {
|
||||
return slice(input, start, stop, strides);
|
||||
@@ -982,8 +982,8 @@ TEST_CASE("test comparison grads") {
|
||||
|
||||
TEST_CASE("test as_strided grads") {
|
||||
auto x = ones({11});
|
||||
std::vector<int> shape = {5, 5};
|
||||
std::vector<size_t> strides = {1, 1};
|
||||
Shape shape = {5, 5};
|
||||
Strides strides = {1, 1};
|
||||
size_t offset = 0;
|
||||
|
||||
auto fun = [&shape, &strides, &offset](array x) {
|
||||
|
Reference in New Issue
Block a user