From b3c1aaafd25e5b270127ab1940648c8d92dc4dbb Mon Sep 17 00:00:00 2001 From: John Mai Date: Sun, 15 Jun 2025 17:35:33 +0800 Subject: [PATCH] update: format code --- benchmarks/python/comparative/bench_mlx.py | 2 ++ benchmarks/python/comparative/bench_torch.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 4fe29a991..658cc6f95 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -223,12 +223,14 @@ def relu6(x): y = nn.relu6(y) mx.eval(y) + def relu_squared(x): y = x for i in range(100): y = nn.relu_squared(y) mx.eval(y) + def softplus(x): y = x for i in range(100): diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index c9a65a819..6c5c518f6 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -156,6 +156,7 @@ def relu6(x): y = torch.nn.functional.relu6(y) sync_if_needed(x) + @torch.no_grad() def relu_squared(x): y = x @@ -164,6 +165,7 @@ def relu_squared(x): y = torch.square(y) sync_if_needed(x) + @torch.no_grad() def softplus(x): y = x