Use fast rope (#945)

* use fast rope

* fix llama

* use fast rope for llama3.1

* requires unreleased mlx

* fix su

* fix deepseek v2

* only one of base or freqs

* nit

* fix

* hard code freqs
This commit is contained in:
Awni Hannun
2024-08-23 13:18:51 -07:00
committed by GitHub
parent 58591a1b41
commit 6731254e76
7 changed files with 65 additions and 137 deletions

View File

@@ -110,7 +110,7 @@ class Transformer2D(nn.Module):
 # Perform the input norm and projection
 B, H, W, C = x.shape
-x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
+x = self.norm(x).reshape(B, -1, C)
 x = self.proj_in(x)
 # Apply the transformer
@@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module):
 if temb is not None:
 temb = self.time_emb_proj(nn.silu(temb))
-y = self.norm1(x.astype(mx.float32)).astype(dtype)
+y = self.norm1(x)
 y = nn.silu(y)
 y = self.conv1(y)
 if temb is not None:
 y = y + temb[:, None, None, :]
-y = self.norm2(y.astype(mx.float32)).astype(dtype)
+y = self.norm2(y)
 y = nn.silu(y)
 y = self.conv2(y)
@@ -453,8 +453,7 @@ class UNetModel(nn.Module):
 )
 # Postprocess the output
 dtype = x.dtype
-x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
+x = self.conv_norm_out(x)
 x = nn.silu(x)
 x = self.conv_out(x)