Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations
* nits

parent 97c09a863d
commit f24edfa9dc
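
Context note (not part of the commit): mx.compile with shapeless=True compiles the normalization once and reuses the compiled graph across input shapes, which is what makes precompiling these small per-layer ops worthwhile. A minimal sketch using the same pattern the diff introduces; the weights and input shapes below are illustrative:

# Minimal sketch of a shapeless-compiled RMS norm (pattern from this commit;
# the weight and input shapes below are illustrative, not from the repo).
from functools import partial

import mlx.core as mx


@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    # Normalize in float32 for stability, then cast back to the weight dtype.
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)


w = mx.ones((64,))
# shapeless=True lets both calls reuse the same compiled graph even though the
# sequence lengths differ, so there is no retrace per shape.
print(rms_norm(mx.random.normal((1, 16, 64)), w, 1e-5).shape)
print(rms_norm(mx.random.normal((1, 128, 64)), w, 1e-5).shape)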
@@ -1,4 +1,4 @@
 from .convert import convert
 from .utils import generate, load

-__version__ = "0.0.13"
+__version__ = "0.0.14"
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from functools import partial
 from typing import Dict, Optional, Tuple, Union

 import mlx.core as mx
@@ -22,18 +23,21 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False


+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return (1.0 + weight) * x.astype(weight.dtype)
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
         self.weight = mx.ones((dims,))
         self.eps = eps

-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
     def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return (1 + self.weight) * output
+        return rms_norm(x, self.weight, self.eps)


 class Attention(nn.Module):
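
Sanity-check sketch (mine, not in the commit): this model scales by (1.0 + weight) rather than plain weight, which is why it keeps a private rms_norm instead of the shared one added in layers.py below. The compiled function should reproduce the removed module code up to numerical tolerance:

# Illustrative check that the compiled rms_norm matches the removed module-based code.
from functools import partial

import mlx.core as mx


@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (1.0 + weight) * x.astype(weight.dtype)


def rms_norm_old(x, weight, eps):
    # The removed __call__: normalize in float32, cast back, then scale by (1 + weight).
    out = x.astype(mx.float32)
    out = out * mx.rsqrt(out.square().mean(-1, keepdims=True) + eps)
    return (1 + weight) * out.astype(x.dtype)


x = mx.random.normal((2, 5, 32))
w = mx.random.normal((32,)) * 0.05
print(mx.allclose(rms_norm(x, w, 1e-5), rms_norm_old(x, w, 1e-5)))  # expected: True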
llms/mlx_lm/models/layers.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+from functools import partial
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return weight * x.astype(weight.dtype)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return rms_norm(x, self.weight, self.eps)
+
+
+@partial(mx.compile, shapeless=True)
+def ln_norm(x, eps, weight=None, bias=None):
+    t = x.dtype
+    x = x.astype(mx.float32)
+    means = mx.mean(x, axis=-1, keepdims=True)
+    var = mx.var(x, axis=-1, keepdims=True)
+    x = (x - means) * mx.rsqrt(var + eps)
+    x = x.astype(t)
+    return weight * x + bias if weight is not None else x
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.bias = mx.zeros((dims,))
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if "weight" in self:
+            return ln_norm(x, self.eps, self.weight, self.bias)
+        else:
+            return ln_norm(x, self.eps)
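
Hypothetical usage sketch of the new shared module (the file path comes from this commit; the mlx_lm.models.layers import path, dimensions, and eps values are my assumptions):

# Hypothetical usage of the shared normalization layers added in this commit.
# Assumes the mlx_lm package is importable; dims/eps values here are illustrative.
import mlx.core as mx

from mlx_lm.models.layers import LayerNorm, RMSNorm

rms = RMSNorm(dims=64, eps=1e-6)
ln = LayerNorm(dims=64, eps=1e-5, affine=True)

x = mx.random.normal((2, 10, 64))
print(rms(x).shape, ln(x).shape)  # both (2, 10, 64); the math runs in float32 internally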
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .layers import RMSNorm


 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
             raise ValueError("rope_scaling 'type' currently only supports 'linear'")


-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np

 from .base import BaseModelArgs
+from .layers import RMSNorm


 @dataclass
@@ -30,20 +31,6 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads


-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class MixtralAttention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .layers import LayerNorm

 try:
     import hf_olmo
@@ -37,11 +38,6 @@
             )


-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .layers import LayerNorm


 @dataclass
@@ -27,11 +28,6 @@
             self.num_key_value_heads = self.num_attention_heads


-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class PhiAttention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -1,17 +1,13 @@
-import glob
 import inspect
-import json
 import math
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Tuple

 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from huggingface_hub import snapshot_download
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer

+from .layers import LayerNorm
+

 @dataclass
@@ -37,11 +33,6 @@ class ModelArgs:
             )


-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class RoPEAttention(nn.Module):
     def __init__(self, dims: int, num_heads: int, rotary_dim: int):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np

 from .base import BaseModelArgs
+from .layers import RMSNorm


 @dataclass
@@ -22,20 +23,6 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False


-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.variance_epsilon = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.variance_epsilon)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .layers import RMSNorm


 @dataclass
@@ -26,20 +27,6 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads


-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .layers import RMSNorm


 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
             raise ValueError("rope_scaling 'type' currently only supports 'linear'")


-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-6):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .layers import LayerNorm


 @dataclass
@@ -24,11 +25,6 @@ class ModelArgs(BaseModelArgs):
     use_qkv_bias: bool


-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.4
 numpy
 transformers>=4.38.0
 protobuf