mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Fix irregular_strides benchmark shape type (#2754)
This commit is contained in:
@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
mx::Shape shape = {8, 8, 512, 512};
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
|
||||
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape;
|
||||
mx::Shape shape;
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
};
|
||||
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
mx::Shape shape = {size, size};
|
||||
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
|
||||
Reference in New Issue
Block a user