Fix irregular_strides benchmark shape type (#2754)

This commit is contained in:
wrmsr
2025-11-11 11:40:22 -08:00
committed by GitHub
parent 047114b988
commit 3fe2250c00

View File

@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
mx::Shape shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
void time_irregular_reshape() {
auto device = mx::default_device();
std::vector<int> shape;
mx::Shape shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
void time_irregular_astype_2D() {
auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
mx::Shape shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);