From 989e8bab66c07ae3d5481fe282b4bcaa256207b8 Mon Sep 17 00:00:00 2001 From: John Mai Date: Sun, 15 Jun 2025 17:34:10 +0800 Subject: [PATCH] feat: Add benchmarking for ReLUSquared activation function --- benchmarks/python/comparative/bench_mlx.py | 8 ++++++++ benchmarks/python/comparative/bench_torch.py | 10 ++++++++++ benchmarks/python/comparative/compare.py | 2 ++ 3 files changed, 20 insertions(+) diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 0821ccae6..4fe29a991 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -223,6 +223,11 @@ def relu6(x): y = nn.relu6(y) mx.eval(y) +def relu_squared(x): + y = x + for i in range(100): + y = nn.relu_squared(y) + mx.eval(y) def softplus(x): y = x @@ -458,6 +463,9 @@ if __name__ == "__main__": elif args.benchmark == "relu6": print(bench(relu6, x)) + elif args.benchmark == "relu_squared": + print(bench(relu_squared, x)) + elif args.benchmark == "celu": print(bench(celu, x)) diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a2157707b..c9a65a819 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -156,6 +156,13 @@ def relu6(x): y = torch.nn.functional.relu6(y) sync_if_needed(x) +@torch.no_grad() +def relu_squared(x): + y = x + for i in range(100): + y = torch.nn.functional.relu(y) + y = torch.square(y) + sync_if_needed(x) @torch.no_grad() def softplus(x): @@ -407,6 +414,9 @@ if __name__ == "__main__": elif args.benchmark == "relu6": print(bench(relu6, x)) + elif args.benchmark == "relu_squared": + print(bench(relu_squared, x)) + elif args.benchmark == "softplus": print(bench(softplus, x)) diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index 68b4a5bd3..e3270e903 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -207,6 +207,8 @@ if __name__ == "__main__": compare_filtered("elu --size 32x16x1024 --cpu") compare_filtered("relu6 --size 32x16x1024") compare_filtered("relu6 --size 32x16x1024 --cpu") + compare_filtered("relu_squared --size 32x16x1024") + compare_filtered("relu_squared --size 32x16x1024 --cpu") compare_filtered("softplus --size 32x16x1024") compare_filtered("softplus --size 32x16x1024 --cpu") compare_filtered("celu --size 32x16x1024")