[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
Awni Hannun 2024-02-22 12:40:55 -08:00 committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions
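
Before the per-file diffs, a minimal standalone sketch of the pattern this commit introduces (shapes and values here are illustrative, not from the repo): each model previously defined its norm as a pure-Python nn.Module, and the change moves the math into a free function wrapped with mx.compile(shapeless=True) so it is compiled once and reused across input shapes.

from functools import partial

import mlx.core as mx


# Compile the norm once; shapeless=True lets the compiled graph be reused
# for any input shape instead of recompiling when the shape changes.
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)  # accumulate in float32 for stability
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)  # cast back via the weight dtype


x = mx.random.normal((2, 8))  # illustrative batch of activations
w = mx.ones((8,))
print(rms_norm(x, w, 1e-5).shape)  # (2, 8)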

View File

@@ -1,4 +1,4 @@
 from .convert import convert
 from .utils import generate, load
 
-__version__ = "0.0.13"
+__version__ = "0.0.14"

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
@ -22,18 +23,21 @@ class ModelArgs(BaseModelArgs):
rope_traditional: bool = False rope_traditional: bool = False
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
x = x.astype(mx.float32)
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (1.0 + weight) * x.astype(weight.dtype)
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5): def __init__(self, dims: int, eps: float = 1e-5):
super().__init__() super().__init__()
self.weight = mx.ones((dims,)) self.weight = mx.ones((dims,))
self.eps = eps self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x): def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype) return rms_norm(x, self.weight, self.eps)
return (1 + self.weight) * output
class Attention(nn.Module): class Attention(nn.Module):

View File

@@ -0,0 +1,51 @@
+from functools import partial
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return weight * x.astype(weight.dtype)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return rms_norm(x, self.weight, self.eps)
+
+
+@partial(mx.compile, shapeless=True)
+def ln_norm(x, eps, weight=None, bias=None):
+    t = x.dtype
+    x = x.astype(mx.float32)
+    means = mx.mean(x, axis=-1, keepdims=True)
+    var = mx.var(x, axis=-1, keepdims=True)
+    x = (x - means) * mx.rsqrt(var + eps)
+    x = x.astype(t)
+    return weight * x + bias if weight is not None else x
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.bias = mx.zeros((dims,))
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if "weight" in self:
+            return ln_norm(x, self.eps, self.weight, self.bias)
+        else:
+            return ln_norm(x, self.eps)
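
For reference, a quick usage sketch of the shared layers added above; the import path mlx_lm.models.layers is an assumption inferred from the `from .layers import ...` lines in the model files below.

import mlx.core as mx

# Assumed import path for the new module added in this commit.
from mlx_lm.models.layers import LayerNorm, RMSNorm

x = mx.random.normal((2, 8))

norm = RMSNorm(dims=8, eps=1e-5)
ln = LayerNorm(dims=8, affine=False)  # no weight/bias, so ln_norm skips the affine step

print(norm(x).shape)  # (2, 8)
print(ln(x).shape)    # (2, 8)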

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -30,20 +31,6 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class MixtralAttention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 try:
     import hf_olmo
@@ -37,11 +38,6 @@ class ModelArgs(BaseModelArgs):
             )
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 
 @dataclass
@@ -27,11 +28,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class PhiAttention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()

View File

@@ -1,17 +1,13 @@
-import glob
 import inspect
-import json
 import math
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from huggingface_hub import snapshot_download
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer
+
+from .layers import LayerNorm
 
 
 @dataclass
@@ -37,11 +33,6 @@ class ModelArgs:
            )
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class RoPEAttention(nn.Module):
     def __init__(self, dims: int, num_heads: int, rotary_dim: int):
         super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -22,20 +23,6 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.variance_epsilon = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.variance_epsilon)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -26,20 +27,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-6):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 
 @dataclass
@@ -24,11 +25,6 @@ class ModelArgs(BaseModelArgs):
     use_qkv_bias: bool
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()

View File

@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.4
 numpy
 transformers>=4.38.0
 protobuf
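
The mlx floor moves from 0.1 to 0.4, presumably because the shapeless mx.compile used by the new norms needs a recent mlx. A quick environment check (assuming mlx.core exposes __version__, as recent releases do) might look like:

import mlx.core as mx

# mlx>=0.4 is the new floor in requirements.txt; older builds may lack
# shapeless compilation, which the precompiled norms rely on.
major, minor = (int(p) for p in mx.__version__.split(".")[:2])
assert (major, minor) >= (0, 4), f"mlx {mx.__version__} is too old"
print(f"mlx {mx.__version__} OK")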