diff --git a/benchmarks/cpp/irregular_strides.cpp b/benchmarks/cpp/irregular_strides.cpp index 552461335..cc4e975c9 100644 --- a/benchmarks/cpp/irregular_strides.cpp +++ b/benchmarks/cpp/irregular_strides.cpp @@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() { void time_irregular_binary_ops_4D() { auto device = mx::default_device(); - std::vector shape = {8, 8, 512, 512}; + mx::Shape shape = {8, 8, 512, 512}; auto a = mx::random::uniform(shape); auto b = mx::random::uniform(shape); @@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() { void time_irregular_reshape() { auto device = mx::default_device(); - std::vector shape; + mx::Shape shape; auto reshape_fn = [&shape, device](const mx::array& a) { return mx::reshape(a, shape, device); }; @@ -170,7 +170,7 @@ void time_irregular_astype_1D() { void time_irregular_astype_2D() { auto device = mx::default_device(); int size = 2048; - std::vector shape = {size, size}; + mx::Shape shape = {size, size}; auto a = mx::random::uniform(shape); TIMEM("2D regular", mx::astype, a, mx::int32, device);