feat: Add benchmarking for ReLUSquared activation function

Author: John Mai
Date:   2025-06-15 17:34:10 +08:00
parent fe0672a9d2
commit 989e8bab66
3 changed files with 20 additions and 0 deletions


@@ -223,6 +223,11 @@ def relu6(x):
         y = nn.relu6(y)
     mx.eval(y)
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
 def softplus(x):
     y = x
@@ -458,6 +463,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
     elif args.benchmark == "celu":
         print(bench(celu, x))
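
For context, MLX exposes relu_squared as a built-in activation, so the benchmark above calls it directly in a 100-iteration loop and forces evaluation with mx.eval. A minimal sketch of the computation being measured, assuming only that relu_squared(x) equals relu(x) squared (the same definition the PyTorch benchmark below composes by hand):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(32, 16, 1024))
y = nn.relu_squared(x)                   # built-in op benchmarked above
y_ref = mx.square(mx.maximum(x, 0.0))    # equivalent composition: max(x, 0) ** 2
mx.eval(y, y_ref)                        # MLX is lazy; force the computation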


@@ -156,6 +156,13 @@ def relu6(x):
         y = torch.nn.functional.relu6(y)
     sync_if_needed(x)
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
 @torch.no_grad()
 def softplus(x):
@@ -407,6 +414,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
     elif args.benchmark == "softplus":
         print(bench(softplus, x))
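
PyTorch has no fused relu_squared operator, so the benchmark composes torch.nn.functional.relu with torch.square, leaving any fusion to the backend. The same computation outside the benchmark harness, as a sketch:

import torch

x = torch.randn(32, 16, 1024)
with torch.no_grad():
    y = torch.square(torch.nn.functional.relu(x))  # relu(x) ** 2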


@@ -207,6 +207,8 @@ if __name__ == "__main__":
     compare_filtered("elu --size 32x16x1024 --cpu")
     compare_filtered("relu6 --size 32x16x1024")
     compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("relu_squared --size 32x16x1024")
+    compare_filtered("relu_squared --size 32x16x1024 --cpu")
     compare_filtered("softplus --size 32x16x1024")
     compare_filtered("softplus --size 32x16x1024 --cpu")
     compare_filtered("celu --size 32x16x1024")