mlx/benchmarks/cpp/single_ops.cpp

250 lines
6.4 KiB
C++
Raw Normal View History

2023-12-01 03:12:53 +08:00
// Copyright © 2023 Apple Inc.
2023-11-30 02:30:41 +08:00
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
auto full_fp32 = [&]() { return full(shape, 3.3f); };
TIME(full_fp32);
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
TIME(zeros_fp32);
auto ones_fp32 = [&]() { return ones(shape, float32); };
TIME(ones_fp32);
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
// Times astype() conversions from float32, int32 and bool sources on a
// 2000 x 500 array. Each source is eval'd before its timings so that building
// the input is not part of the timed call.
void time_type_conversions() {
  const int rows = 2000;
  const int cols = 500;
  auto dims = {rows, cols};
  auto dev = default_device();

  // float32 source.
  auto src = zeros(dims, float32);
  eval(src);
  TIMEM("float32 to int32", astype, src, int32, dev);
  TIMEM("float32 to uint32", astype, src, uint32, dev);

  // int32 source.
  src = zeros(dims, int32);
  eval(src);
  TIMEM("int32 to float32", astype, src, float32, dev);

  // bool source.
  src = zeros(dims, bool_);
  eval(src);
  TIMEM("bool to float32", astype, src, float32, dev);
  TIMEM("bool to int32", astype, src, int32, dev);
  TIMEM("bool to uint32", astype, src, uint32, dev);
}
// Times the random::uniform and random::normal samplers, each producing a
// 2000 x 500 float32 array per call.
void time_random_generation() {
  const int rows = 2000;
  const int cols = 500;

  auto uniform = [=]() {
    return random::uniform({rows, cols}, float32);
  };
  TIME(uniform);

  auto normal = [=]() {
    return random::normal({rows, cols}, float32);
  };
  TIME(normal);
}
// Times elementwise unary ops on a 2000 x 500 input. As elsewhere in this
// file, every input is eval'd before it is timed, so only the op itself is
// measured.
void time_unary_ops() {
  int M = 2000;
  int N = 500;
  auto device = default_device();

  auto a = random::normal({M, N});
  eval(a);
  TIME(mlx::core::abs, a, device);
  TIME(negative, a, device);
  TIME(sign, a, device);
  TIME(square, a, device);
  TIME(mlx::core::sqrt, a, device);
  TIME(rsqrt, a, device);
  TIME(mlx::core::exp, a, device);

  // log gets a fresh uniform sample (presumably in [0, 1)) instead of the
  // normal one, which contains negative values. It must be eval'd here too.
  // Otherwise the lazily built sampling graph is materialized inside the
  // first timed call of log, skewing that measurement.
  a = random::uniform({M, N});
  eval(a);
  TIME(mlx::core::log, a, device);
}
// Times elementwise binary arithmetic in three operand layouts:
//   1. two full 1000 x 100 x 10 arrays,
//   2. a full array with a one-element array (both argument orders),
//   3. two one-element arrays broadcast to 1000 x 100.
void time_binary_ops() {
  const int d0 = 1000;
  const int d1 = 100;
  const int d2 = 10;

  // 1. Same-shape operands.
  auto lhs = random::uniform({d0, d1, d2});
  auto rhs = random::uniform({d0, d1, d2});
  auto dev = default_device();
  eval(lhs, rhs);
  TIME(add, lhs, rhs, dev);
  TIME(subtract, lhs, rhs, dev);
  TIME(multiply, lhs, rhs, dev);
  TIME(divide, lhs, rhs, dev);
  TIME(maximum, lhs, rhs, dev);
  TIME(minimum, lhs, rhs, dev);

  // 2. Full operand paired with a single-element operand.
  rhs = random::uniform({1});
  eval(rhs);
  TIMEM("scalar", add, lhs, rhs, dev);
  TIMEM("vector-scalar", subtract, lhs, rhs, dev);
  TIMEM("scalar-vector", subtract, rhs, lhs, dev);
  TIMEM("scalar", multiply, lhs, rhs, dev);
  TIMEM("vector-scalar", divide, lhs, rhs, dev);
  TIMEM("scalar-vector", divide, rhs, lhs, dev);

  // 3. Two single-element arrays broadcast up to a 2-D shape.
  lhs = broadcast_to(random::uniform({1}), {d0, d1});
  rhs = broadcast_to(random::uniform({1}), {d0, d1});
  eval(lhs, rhs);
  TIMEM("scalar-scalar broadcast", add, lhs, rhs, dev);
  TIMEM("scalar-scalar broadcast", subtract, lhs, rhs, dev);
  TIMEM("scalar-scalar broadcast", multiply, lhs, rhs, dev);
  TIMEM("scalar-scalar broadcast", divide, lhs, rhs, dev);
}
// Compares `add` on two freshly sampled 50 x 50 x 50 x 50 arrays against the
// same data after each operand is transposed with a different axis order.
void time_strided_ops() {
  const int dim = 50;
  auto dev = default_device();

  auto x = random::uniform({dim, dim, dim, dim});
  auto y = random::uniform({dim, dim, dim, dim});
  eval(x, y);
  TIMEM("non-strided", add, x, y, dev);

  // Distinct permutations per operand; the results are presumably strided
  // views, so the kernel cannot treat the inputs as one flat buffer.
  x = transpose(x, {1, 0, 2, 3});
  y = transpose(y, {3, 2, 0, 1});
  eval(x, y);
  TIMEM("strided", add, x, y, dev);
}
// Times elementwise comparison ops on two 1000 x 100 x 10 uniform samples.
void time_comparisons() {
  const int d0 = 1000;
  const int d1 = 100;
  const int d2 = 10;

  auto x = random::uniform({d0, d1, d2});
  auto y = random::uniform({d0, d1, d2});
  auto dev = default_device();
  eval(x, y);

  // Equality first, then the four ordering comparisons.
  TIME(equal, x, y, dev);
  TIME(greater, x, y, dev);
  TIME(greater_equal, x, y, dev);
  TIME(less, x, y, dev);
  TIME(less_equal, x, y, dev);
}
void time_matvec() {
int M = 2000, N = 200;
auto a = random::uniform({M, N});
auto b = random::uniform({N});
auto c = random::uniform({M});
eval(a, b, c);
auto matvec = [&]() { return matmul(a, b); };
TIME(matvec);
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
auto a = random::uniform({M, K});
auto b = random::uniform({K, N});
auto device = default_device();
eval(a, b);
TIME(matmul, a, b, device);
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
TIME(transpose_matmul);
}
// Times reductions over a 10000 x 1000 normal sample: over the whole array
// and along each axis. Every call passes a trailing `false` (presumably
// keepdims = false, i.e. reduced axes are dropped).
void time_reductions() {
  auto x = random::normal({10000, 1000});
  eval(x);

  // sum: whole array, then axis 0, then axis 1.
  auto sum_all = [&]() { return sum(x, false); };
  TIME(sum_all);
  auto sum_along_0 = [&]() { return sum(x, 0, false); };
  TIME(sum_along_0);
  auto sum_along_1 = [&]() { return sum(x, 1, false); };
  TIME(sum_along_1);

  // prod over the whole array.
  auto prod_all = [&]() { return prod(x, false); };
  TIME(prod_all);

  // Boolean reductions.
  auto all_true = [&]() { return all(x, false); };
  TIME(all_true);
  auto all_along_0 = [&]() { return all(x, 0, false); };
  TIME(all_along_0);
  auto all_along_1 = [&]() { return all(x, 1, false); };
  TIME(all_along_1);
  auto any_true = [&]() { return any(x, false); };
  TIME(any_true);

  // Index-returning reductions along each axis.
  auto argmin_along_0 = [&]() { return argmin(x, 0, false); };
  TIME(argmin_along_0);
  auto argmin_along_1 = [&]() { return argmin(x, 1, false); };
  TIME(argmin_along_1);
}
// Times gather (take) and scatter / scatter_add on a 1000 x 768 table.
// The first half works on whole rows (256 random row indices along axis 0).
// The second half works on 256 * 768 random individual positions in the
// flattened table. Each lambda captures by reference, so it sees the `a`,
// `indices` and `updates` values current at the time it is timed.
void time_gather_scatter() {
  const int num_rows = 1000;
  const int row_dim = 768;
  const int batch = 256;

  auto a = random::normal({num_rows, row_dim});
  eval(a);

  // Row gather along axis 0.
  auto indices = random::randint(0, num_rows, {batch});
  eval(indices);
  auto embedding_lookup = [&]() { return take(a, indices, 0); };
  TIME(embedding_lookup);

  // Element gather: take without an axis (presumably indexes the flattened
  // array).
  indices = random::randint(0, row_dim * num_rows, {batch * row_dim});
  eval(indices);
  auto single_element_lookup = [&]() { return take(a, indices); };
  TIME(single_element_lookup);

  // Row overwrite / accumulate along axis 0; updates carry a size-1 slice
  // dimension for the scattered axis.
  indices = random::randint(0, num_rows, {batch});
  auto updates = random::normal({batch, 1, row_dim});
  eval(indices, updates);
  auto embedding_update = [&]() {
    return scatter(a, indices, updates, 0);
  };
  TIME(embedding_update);
  auto embedding_add = [&]() {
    return scatter_add(a, indices, updates, 0);
  };
  TIME(embedding_add);

  // Element overwrite / accumulate on the flattened table.
  a = reshape(a, {-1});
  indices = random::randint(0, row_dim * num_rows, {row_dim * batch});
  updates = random::normal({batch * row_dim, 1});
  eval(a, indices, updates);
  auto single_element_update = [&]() {
    return scatter(a, indices, updates, 0);
  };
  TIME(single_element_update);
  auto single_element_add = [&]() {
    return scatter_add(a, indices, updates, 0);
  };
  TIME(single_element_add);
}
// Entry point: prints the default device, then runs each benchmark suite
// once, in a fixed order.
int main() {
  std::cout << "Benchmarks for " << default_device() << std::endl;

  void (*const suites[])() = {
      time_creation_ops,
      time_type_conversions,
      time_unary_ops,
      time_binary_ops,
      time_strided_ops,
      time_random_generation,
      time_comparisons,
      time_matvec,
      time_matmul,
      time_reductions,
      time_gather_scatter,
  };
  for (auto run_suite : suites) {
    run_suite();
  }
}