fixup! feat: Nemotron

Lllvvuu
2024-08-23 00:21:23 +09:00
parent 9c86658aff
commit 01e1f3c9e4
2 changed files with 11 additions and 34 deletions

Deleted file:

@@ -1,30 +0,0 @@
"""
Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
"""
import mlx.nn as nn
class ReLUSquared(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def __call__(self, input):
return nn.relu(input).square()
ACT2FN = {
"gelu": nn.GELU,
"mish": nn.Mish,
"relu": nn.ReLU,
"relu2": ReLUSquared,
"relu6": nn.ReLU6,
"silu": nn.SiLU,
"swish": nn.SiLU,
"tanh": nn.Tanh,
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
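
Note: the registry above was consumed through `get_activation(<config string>)`. A minimal usage sketch, assuming the deleted module was importable as `mlx_lm.activations` (inferred from the import that the next file drops) and that the relevant config string is `"relu2"`:

import mlx.core as mx

# Assumed import path, inferred from the `from mlx_lm.activations import get_activation`
# line removed in the next file; this only works before this fixup, while the module exists.
from mlx_lm.activations import get_activation

act = get_activation("relu2")()   # "relu2" resolves to the ReLUSquared class
x = mx.array([-2.0, 0.0, 1.5])
print(act(x))                     # relu(x)^2 -> array([0, 0, 2.25], dtype=float32)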

Changed file:

@@ -1,13 +1,11 @@
 # Copyright © 2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union

 import mlx.core as mx
 import mlx.nn as nn

-from mlx_lm.activations import get_activation
-
 from .base import BaseModelArgs, KVCache, create_attention_mask
@@ -52,6 +50,15 @@ class ModelArgs(BaseModelArgs):
                 )


+class ReLUSquared(nn.Module):
+    """
+    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+    """
+
+    def __call__(self, input):
+        return nn.relu(input).square()
+
+
 class NemotronLayerNorm1P(nn.LayerNorm):
     def __call__(self, x):
         weight = self.weight + 1 if "weight" in self else None
@@ -134,7 +141,7 @@ class MLP(nn.Module):
         self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
         self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-        self.act_fn = get_activation(args.hidden_act)()
+        self.act_fn = ReLUSquared()

     def __call__(self, x) -> mx.array:
         return self.down_proj(self.act_fn(self.up_proj(x)))
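
Net effect: `MLP` now constructs `ReLUSquared` directly instead of resolving `args.hidden_act` through the deleted registry, on the assumption baked into this fixup that Nemotron's hidden activation is always the squared ReLU. A small self-contained sketch checking that the inlined class matches the activation written out by hand (`ReLUSquared` copied from the hunk above; `manual` is illustrative):

import mlx.core as mx
import mlx.nn as nn


class ReLUSquared(nn.Module):
    """relu(x)^2, as added in the hunk above."""

    def __call__(self, input):
        return nn.relu(input).square()


x = mx.random.normal((4, 8))

# The same activation written out by hand: clamp negatives to zero, then square.
manual = mx.maximum(x, 0) ** 2
print(mx.allclose(ReLUSquared()(x), manual))  # -> array(True, dtype=bool)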