fixup! feat: Nemotron

Lllvvuu
2024-08-23 00:21:23 +09:00
parent 9c86658aff
commit 01e1f3c9e4
2 changed files with 11 additions and 34 deletions

View File

@@ -1,30 +0,0 @@
"""
Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
"""
import mlx.nn as nn
class ReLUSquared(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def __call__(self, input):
return nn.relu(input).square()
ACT2FN = {
"gelu": nn.GELU,
"mish": nn.Mish,
"relu": nn.ReLU,
"relu2": ReLUSquared,
"relu6": nn.ReLU6,
"silu": nn.SiLU,
"swish": nn.SiLU,
"tanh": nn.Tanh,
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")

View File

@@ -1,13 +1,11 @@
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

-from mlx_lm.activations import get_activation

from .base import BaseModelArgs, KVCache, create_attention_mask
@@ -52,6 +50,15 @@ class ModelArgs(BaseModelArgs):
    )


+class ReLUSquared(nn.Module):
+    """
+    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+    """
+
+    def __call__(self, input):
+        return nn.relu(input).square()
+
+
class NemotronLayerNorm1P(nn.LayerNorm):
    def __call__(self, x):
        weight = self.weight + 1 if "weight" in self else None
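
For context on the NemotronLayerNorm1P lines above: the stored weight is zero-centered and the effective scale is weight + 1, so zero-initialized weights behave like a standard LayerNorm at init. Only the weight + 1 line appears in this diff; the body below is an assumed sketch of how such a "1p" norm can be wired up with mx.fast.layer_norm, under a hypothetical name:

import mlx.core as mx
import mlx.nn as nn


class LayerNorm1P(nn.LayerNorm):  # hypothetical stand-in, not the class from this diff
    def __call__(self, x):
        # effective scale is the stored weight shifted by +1
        weight = self.weight + 1 if "weight" in self else None
        bias = self.bias if "bias" in self else None
        return mx.fast.layer_norm(x, weight, bias, self.eps)


x = mx.random.normal((2, 4, 16))
print(LayerNorm1P(16)(x).shape)  # (2, 4, 16)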
@@ -134,7 +141,7 @@ class MLP(nn.Module):
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-        self.act_fn = get_activation(args.hidden_act)()
+        self.act_fn = ReLUSquared()

    def __call__(self, x) -> mx.array:
        return self.down_proj(self.act_fn(self.up_proj(x)))
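
The resulting MLP is a plain up-project / relu² / down-project block with no gate projection. A self-contained sketch, not part of the diff; the constructor signature and dimensions are invented for illustration, since the diff shows only a fragment of the class:

import mlx.core as mx
import mlx.nn as nn


class ReLUSquared(nn.Module):
    def __call__(self, input):
        return nn.relu(input).square()


class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, mlp_bias: bool = False):
        super().__init__()
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.act_fn = ReLUSquared()

    def __call__(self, x) -> mx.array:
        return self.down_proj(self.act_fn(self.up_proj(x)))


x = mx.random.normal((2, 8, 32))
print(MLP(dim=32, hidden_dim=128)(x).shape)  # (2, 8, 32)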