Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 21:01:32 +08:00)
Use fast rope (#945)
* use fast rope
* fix llama
* use fast rope for llama3.1
* requires unreleased mlx
* fix su
* fix deepseek v2
* only one of base or freqs
* nit
* fix
* hard code freqs
This commit is contained in:
@@ -110,7 +110,7 @@ class Transformer2D(nn.Module):
 
 # Perform the input norm and projection
 B, H, W, C = x.shape
-x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
+x = self.norm(x).reshape(B, -1, C)
 x = self.proj_in(x)
 
 # Apply the transformer
@@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module):
 if temb is not None:
     temb = self.time_emb_proj(nn.silu(temb))
 
-y = self.norm1(x.astype(mx.float32)).astype(dtype)
+y = self.norm1(x)
 y = nn.silu(y)
 y = self.conv1(y)
 if temb is not None:
     y = y + temb[:, None, None, :]
-y = self.norm2(y.astype(mx.float32)).astype(dtype)
+y = self.norm2(y)
 y = nn.silu(y)
 y = self.conv2(y)
 
@@ -453,8 +453,7 @@ class UNetModel(nn.Module):
 )
 
 # Postprocess the output
-dtype = x.dtype
-x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
+x = self.conv_norm_out(x)
 x = nn.silu(x)
 x = self.conv_out(x)
 
Reference in New Issue
Block a user