mirror of https://github.com/ml-explore/mlx.git
update: format code
@@ -223,12 +223,14 @@ def relu6(x):
         y = nn.relu6(y)
     mx.eval(y)
 
+
 def relu_squared(x):
     y = x
     for i in range(100):
         y = nn.relu_squared(y)
     mx.eval(y)
 
+
 def softplus(x):
     y = x
     for i in range(100):
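The two added lines in the MLX benchmark hunk are the second blank line between top-level definitions, consistent with black/PEP 8 formatting and the "format code" commit message; no behavior changes. For context, the touched functions follow the file's common benchmark shape: chain the activation 100 times, then force MLX's lazy graph to run. A minimal sketch of that pattern, with the standard imports assumed (the timing harness around it is not part of the diff):

import mlx.core as mx
import mlx.nn as nn


def relu_squared(x):
    # Build a 100-op chain; MLX is lazy, so nothing executes yet.
    y = x
    for i in range(100):
        y = nn.relu_squared(y)
    # Force evaluation so the benchmark measures the actual compute.
    mx.eval(y)

The PyTorch comparison file gets the same blank-line treatment: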
@@ -156,6 +156,7 @@ def relu6(x):
         y = torch.nn.functional.relu6(y)
     sync_if_needed(x)
 
+
 @torch.no_grad()
 def relu_squared(x):
     y = x
@@ -164,6 +165,7 @@ def relu_squared(x):
         y = torch.square(y)
     sync_if_needed(x)
 
+
 @torch.no_grad()
 def softplus(x):
     y = x
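A minimal sketch of the PyTorch pattern these two hunks sit in, under stated assumptions: only `y = torch.square(y)` is visible in the hunk, so the preceding `relu` call is an assumption about how the file composes relu_squared, and the `sync_if_needed` body shown here is an illustrative guess at the repo's device-sync helper, not its actual code:

import torch


def sync_if_needed(x):
    # Assumption: a plausible body for the repo's helper. It waits for the
    # device so timing covers finished work, not just kernel launches.
    if x.is_cuda:
        torch.cuda.synchronize()


@torch.no_grad()  # benchmarks run inference-only, skipping autograd overhead
def relu_squared(x):
    y = x
    for i in range(100):
        y = torch.nn.functional.relu(y)  # assumption: relu precedes square
        y = torch.square(y)
    sync_if_needed(x)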