Fix irregular_strides benchmark shape type (#2754)

This commit is contained in:
wrmsr
2025-11-11 11:40:22 -08:00
committed by GitHub
parent 047114b988
commit 3fe2250c00

View File

@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
void time_irregular_binary_ops_4D() { void time_irregular_binary_ops_4D() {
auto device = mx::default_device(); auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512}; mx::Shape shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape); auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape); auto b = mx::random::uniform(shape);
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
void time_irregular_reshape() { void time_irregular_reshape() {
auto device = mx::default_device(); auto device = mx::default_device();
std::vector<int> shape; mx::Shape shape;
auto reshape_fn = [&shape, device](const mx::array& a) { auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device); return mx::reshape(a, shape, device);
}; };
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
void time_irregular_astype_2D() { void time_irregular_astype_2D() {
auto device = mx::default_device(); auto device = mx::default_device();
int size = 2048; int size = 2048;
std::vector<int> shape = {size, size}; mx::Shape shape = {size, size};
auto a = mx::random::uniform(shape); auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device); TIMEM("2D regular", mx::astype, a, mx::int32, device);