mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
feat: Add benchmarking for ReLUSquared activation function
This commit is contained in:
parent
fe0672a9d2
commit
989e8bab66
@ -223,6 +223,11 @@ def relu6(x):
|
||||
y = nn.relu6(y)
|
||||
mx.eval(y)
|
||||
|
||||
def relu_squared(x):
    """Benchmark the mlx ReLUSquared activation: 100 chained applications,
    then force evaluation of the lazy graph."""
    out = x
    for _ in range(100):
        out = nn.relu_squared(out)
    mx.eval(out)
|
||||
|
||||
def softplus(x):
|
||||
y = x
|
||||
@ -458,6 +463,9 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "relu_squared":
|
||||
print(bench(relu_squared, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
|
@ -156,6 +156,13 @@ def relu6(x):
|
||||
y = torch.nn.functional.relu6(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
@torch.no_grad()
def relu_squared(x):
    """Benchmark ReLUSquared in torch: 100 chained relu-then-square passes,
    followed by a device sync so timing captures the actual work."""
    out = x
    for _ in range(100):
        out = torch.square(torch.nn.functional.relu(out))
    sync_if_needed(x)
|
||||
|
||||
@torch.no_grad()
|
||||
def softplus(x):
|
||||
@ -407,6 +414,9 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "relu_squared":
|
||||
print(bench(relu_squared, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
|
@ -207,6 +207,8 @@ if __name__ == "__main__":
|
||||
compare_filtered("elu --size 32x16x1024 --cpu")
|
||||
compare_filtered("relu6 --size 32x16x1024")
|
||||
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
||||
compare_filtered("relu_squared --size 32x16x1024")
|
||||
compare_filtered("relu_squared --size 32x16x1024 --cpu")
|
||||
compare_filtered("softplus --size 32x16x1024")
|
||||
compare_filtered("softplus --size 32x16x1024 --cpu")
|
||||
compare_filtered("celu --size 32x16x1024")
|
||||
|
Loading…
Reference in New Issue
Block a user