update: format code

This commit is contained in:
John Mai 2025-06-15 17:35:33 +08:00
parent 989e8bab66
commit b3c1aaafd2
2 changed files with 4 additions and 0 deletions

View File

@ -223,12 +223,14 @@ def relu6(x):
y = nn.relu6(y)
mx.eval(y)
def relu_squared(x):
y = x
for i in range(100):
y = nn.relu_squared(y)
mx.eval(y)
def softplus(x):
y = x
for i in range(100):

View File

@ -156,6 +156,7 @@ def relu6(x):
y = torch.nn.functional.relu6(y)
sync_if_needed(x)
@torch.no_grad()
def relu_squared(x):
y = x
@ -164,6 +165,7 @@ def relu_squared(x):
y = torch.square(y)
sync_if_needed(x)
@torch.no_grad()
def softplus(x):
y = x