diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 0821ccae6..4fe29a991 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -223,6 +223,11 @@ def relu6(x): y = nn.relu6(y) mx.eval(y) +def relu_squared(x): + y = x + for i in range(100): + y = nn.relu_squared(y) + mx.eval(y) def softplus(x): y = x @@ -458,6 +463,9 @@ if __name__ == "__main__": elif args.benchmark == "relu6": print(bench(relu6, x)) + elif args.benchmark == "relu_squared": + print(bench(relu_squared, x)) + elif args.benchmark == "celu": print(bench(celu, x)) diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index a2157707b..c9a65a819 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -156,6 +156,13 @@ def relu6(x): y = torch.nn.functional.relu6(y) sync_if_needed(x) +@torch.no_grad() +def relu_squared(x): + y = x + for i in range(100): + y = torch.nn.functional.relu(y) + y = torch.square(y) + sync_if_needed(x) @torch.no_grad() def softplus(x): @@ -407,6 +414,9 @@ if __name__ == "__main__": elif args.benchmark == "relu6": print(bench(relu6, x)) + elif args.benchmark == "relu_squared": + print(bench(relu_squared, x)) + elif args.benchmark == "softplus": print(bench(softplus, x)) diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index 68b4a5bd3..e3270e903 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -207,6 +207,8 @@ if __name__ == "__main__": compare_filtered("elu --size 32x16x1024 --cpu") compare_filtered("relu6 --size 32x16x1024") compare_filtered("relu6 --size 32x16x1024 --cpu") + compare_filtered("relu_squared --size 32x16x1024") + compare_filtered("relu_squared --size 32x16x1024 --cpu") compare_filtered("softplus --size 32x16x1024") compare_filtered("softplus --size 32x16x1024 --cpu") compare_filtered("celu --size 32x16x1024")