mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
awni's commit files
This commit is contained in:
198
benchmarks/cpp/irregular_strides.cpp
Normal file
198
benchmarks/cpp/irregular_strides.cpp
Normal file
@@ -0,0 +1,198 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void time_irregular_binary_ops_1D() {
|
||||
auto device = default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
b = slice(b, {0}, {size}, {step});
|
||||
TIMEM("1D strided", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_2D() {
|
||||
auto device = default_device();
|
||||
int size = 2048;
|
||||
auto a = random::uniform({size, size});
|
||||
auto b = random::uniform({size, size});
|
||||
eval(a, b);
|
||||
TIMEM("2D regular", add, a, b, device);
|
||||
|
||||
b = transpose(b);
|
||||
eval(b);
|
||||
TIMEM("2D transpose", add, a, b, device);
|
||||
|
||||
b = random::uniform({size});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
||||
|
||||
b = reshape(b, {size, 1});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_3D() {
|
||||
auto device = default_device();
|
||||
int d0 = 32;
|
||||
int d1 = 512;
|
||||
int d2 = 512;
|
||||
auto a = random::uniform({d0, d1, d2});
|
||||
auto b = random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 2, 1});
|
||||
TIMEM("3D transpose", add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
||||
|
||||
b = random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
auto a = random::uniform(shape);
|
||||
auto b = random::uniform(shape);
|
||||
|
||||
TIMEM("4D regular", add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D transpose", add, a, b, device);
|
||||
|
||||
std::string om = "4D broadcast dims ";
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = 1;
|
||||
b = random::uniform(shape);
|
||||
std::ostringstream msg;
|
||||
msg << om << i;
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
|
||||
for (int j = i + 1; j < shape.size(); ++j) {
|
||||
shape[j] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
shape[j] = a.shape(j);
|
||||
|
||||
for (int k = j + 1; k < shape.size(); ++k) {
|
||||
shape[k] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j << ", " << k;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
shape[k] = a.shape(k);
|
||||
}
|
||||
}
|
||||
shape[i] = a.shape(i);
|
||||
}
|
||||
}
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = default_device();
|
||||
std::vector<int> shape;
|
||||
auto reshape_fn = [&shape, device](const array& a) {
|
||||
return reshape(a, shape, device);
|
||||
};
|
||||
|
||||
int size = 64;
|
||||
int d = 2 * size;
|
||||
|
||||
auto a = random::uniform({d, d, d});
|
||||
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D contiguous", reshape_fn, a);
|
||||
|
||||
a = transpose(a);
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose", reshape_fn, a);
|
||||
|
||||
a = transpose(a, {1, 2, 0});
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||
}
|
||||
|
||||
void time_irregular_astype_1D() {
|
||||
auto device = default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
TIMEM("1D strided", astype, a, int32, device);
|
||||
}
|
||||
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
|
||||
auto a = random::uniform(shape);
|
||||
TIMEM("2D regular", astype, a, int32, device);
|
||||
|
||||
a = transpose(a);
|
||||
TIMEM("2D transpose", astype, a, int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 1) {
|
||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
||||
}
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
time_irregular_binary_ops_1D();
|
||||
time_irregular_binary_ops_2D();
|
||||
time_irregular_binary_ops_3D();
|
||||
time_irregular_binary_ops_4D();
|
||||
time_irregular_reshape();
|
||||
time_irregular_astype_1D();
|
||||
time_irregular_astype_2D();
|
||||
}
|
247
benchmarks/cpp/single_ops.cpp
Normal file
247
benchmarks/cpp/single_ops.cpp
Normal file
@@ -0,0 +1,247 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
void time_creation_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||
TIME(full_fp32);
|
||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||
TIME(zeros_fp32);
|
||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||
TIME(ones_fp32);
|
||||
|
||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||
TIME(arange_fp32);
|
||||
}
|
||||
|
||||
void time_type_conversions() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto device = default_device();
|
||||
|
||||
auto a = zeros(shape, float32);
|
||||
eval(a);
|
||||
TIMEM("float32 to int32", astype, a, int32, device);
|
||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||
|
||||
a = zeros(shape, int32);
|
||||
eval(a);
|
||||
TIMEM("int32 to float32", astype, a, float32, device);
|
||||
|
||||
a = zeros(shape, bool_);
|
||||
eval(a);
|
||||
TIMEM("bool to float32", astype, a, float32, device);
|
||||
TIMEM("bool to int32", astype, a, int32, device);
|
||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||
}
|
||||
|
||||
void time_random_generation() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
|
||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||
TIME(uniform);
|
||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||
TIME(normal);
|
||||
}
|
||||
|
||||
void time_unary_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto device = default_device();
|
||||
|
||||
auto a = random::normal({M, N});
|
||||
eval(a);
|
||||
TIME(mlx::core::abs, a, device);
|
||||
TIME(negative, a, device);
|
||||
TIME(sign, a, device);
|
||||
TIME(square, a, device);
|
||||
TIME(mlx::core::sqrt, a, device);
|
||||
TIME(rsqrt, a, device);
|
||||
TIME(mlx::core::exp, a, device);
|
||||
|
||||
a = random::uniform({M, N});
|
||||
TIME(mlx::core::log, a, device);
|
||||
}
|
||||
|
||||
void time_binary_ops() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
|
||||
TIME(add, a, b, device);
|
||||
TIME(subtract, a, b, device);
|
||||
TIME(multiply, a, b, device);
|
||||
TIME(divide, a, b, device);
|
||||
TIME(maximum, a, b, device);
|
||||
TIME(minimum, a, b, device);
|
||||
|
||||
b = random::uniform({1});
|
||||
eval(b);
|
||||
TIMEM("scalar", add, a, b, device);
|
||||
TIMEM("vector-scalar", subtract, a, b, device);
|
||||
TIMEM("scalar-vector", subtract, b, a, device);
|
||||
TIMEM("scalar", multiply, a, b, device);
|
||||
TIMEM("vector-scalar", divide, a, b, device);
|
||||
TIMEM("scalar-vector", divide, b, a, device);
|
||||
|
||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||
}
|
||||
|
||||
void time_strided_ops() {
|
||||
int M = 50, N = 50, O = 50, P = 50;
|
||||
auto a = random::uniform({M, N, O, P});
|
||||
auto b = random::uniform({M, N, O, P});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIMEM("non-strided", add, a, b, device);
|
||||
a = transpose(a, {1, 0, 2, 3});
|
||||
b = transpose(b, {3, 2, 0, 1});
|
||||
eval(a, b);
|
||||
TIMEM("strided", add, a, b, device);
|
||||
}
|
||||
|
||||
void time_comparisons() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(equal, a, b, device);
|
||||
TIME(greater, a, b, device);
|
||||
TIME(greater_equal, a, b, device);
|
||||
TIME(less, a, b, device);
|
||||
TIME(less_equal, a, b, device);
|
||||
}
|
||||
|
||||
void time_matvec() {
|
||||
int M = 2000, N = 200;
|
||||
auto a = random::uniform({M, N});
|
||||
auto b = random::uniform({N});
|
||||
auto c = random::uniform({M});
|
||||
eval(a, b, c);
|
||||
auto matvec = [&]() { return matmul(a, b); };
|
||||
TIME(matvec);
|
||||
|
||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
||||
TIME(matvec_transpose);
|
||||
}
|
||||
|
||||
void time_matmul() {
|
||||
int M = 1000, N = 1000, K = 1000;
|
||||
auto a = random::uniform({M, K});
|
||||
auto b = random::uniform({K, N});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(matmul, a, b, device);
|
||||
|
||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
||||
TIME(transpose_matmul);
|
||||
}
|
||||
|
||||
void time_reductions() {
|
||||
auto a = random::normal({10000, 1000});
|
||||
eval(a);
|
||||
auto sum_all = [&a]() { return sum(a, false); };
|
||||
TIME(sum_all);
|
||||
|
||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
||||
TIME(sum_along_0);
|
||||
|
||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
||||
TIME(sum_along_1);
|
||||
|
||||
auto prod_all = [&a]() { return prod(a, false); };
|
||||
TIME(prod_all);
|
||||
|
||||
auto all_true = [&a]() { return all(a, false); };
|
||||
TIME(all_true);
|
||||
|
||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
||||
TIME(all_along_0);
|
||||
|
||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
||||
TIME(all_along_1);
|
||||
|
||||
auto any_true = [&a]() { return any(a, false); };
|
||||
TIME(any_true);
|
||||
|
||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
||||
TIME(argmin_along_0);
|
||||
|
||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
auto a = random::normal({1000, 768});
|
||||
eval(a);
|
||||
auto indices = random::randint(0, 1000, {256});
|
||||
eval(indices);
|
||||
|
||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
||||
TIME(embedding_lookup);
|
||||
|
||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
||||
eval(indices);
|
||||
|
||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
||||
TIME(single_element_lookup);
|
||||
|
||||
indices = random::randint(0, 1000, {256});
|
||||
auto updates = random::normal({256, 1, 768});
|
||||
eval(indices, updates);
|
||||
|
||||
auto embedding_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
};
|
||||
TIME(embedding_update);
|
||||
|
||||
auto embedding_add = [&a, &indices, &updates]() {
|
||||
return scatter_add(a, indices, updates, 0);
|
||||
};
|
||||
TIME(embedding_add);
|
||||
|
||||
a = reshape(a, {-1});
|
||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = random::normal({256 * 768, 1});
|
||||
eval(a, indices, updates);
|
||||
|
||||
auto single_element_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
};
|
||||
TIME(single_element_update);
|
||||
|
||||
auto single_element_add = [&a, &indices, &updates]() {
|
||||
return scatter_add(a, indices, updates, 0);
|
||||
};
|
||||
TIME(single_element_add);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
time_creation_ops();
|
||||
time_type_conversions();
|
||||
time_unary_ops();
|
||||
time_binary_ops();
|
||||
time_strided_ops();
|
||||
time_random_generation();
|
||||
time_comparisons();
|
||||
time_matvec();
|
||||
time_matmul();
|
||||
time_reductions();
|
||||
time_gather_scatter();
|
||||
}
|
Reference in New Issue
Block a user