Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-08 01:54:34 +08:00)
fixup! feat: Nemotron
@@ -1,30 +0,0 @@
-"""
-Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
-"""
-
-import mlx.nn as nn
-
-class ReLUSquared(nn.Module):
-    """
-    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
-    """
-
-    def __call__(self, input):
-        return nn.relu(input).square()
-
-ACT2FN = {
-    "gelu": nn.GELU,
-    "mish": nn.Mish,
-    "relu": nn.ReLU,
-    "relu2": ReLUSquared,
-    "relu6": nn.ReLU6,
-    "silu": nn.SiLU,
-    "swish": nn.SiLU,
-    "tanh": nn.Tanh,
-}
-
-def get_activation(activation_string):
-    if activation_string in ACT2FN:
-        return ACT2FN[activation_string]
-    else:
-        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
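This hunk deletes the standalone activation-mapping helper (the module imported as mlx_lm.activations in the next hunk). As a minimal sketch of what the relu^2 activation computes, assuming an environment with mlx installed (the input values below are purely illustrative):

    import mlx.core as mx
    import mlx.nn as nn

    class ReLUSquared(nn.Module):
        """relu^2: clamp negatives to zero, then square (https://arxiv.org/abs/2109.08668v2)."""

        def __call__(self, input):
            return nn.relu(input).square()

    x = mx.array([-2.0, -0.5, 0.0, 1.5, 3.0])
    print(ReLUSquared()(x))         # [0, 0, 0, 2.25, 9]
    print(mx.maximum(x, 0.0) ** 2)  # same values, written without the module wrapper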
@@ -1,13 +1,11 @@
 # Copyright © 2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.activations import get_activation
-
 from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
@@ -52,6 +50,15 @@ class ModelArgs(BaseModelArgs):
             )
 
 
+class ReLUSquared(nn.Module):
+    """
+    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+    """
+
+    def __call__(self, input):
+        return nn.relu(input).square()
+
+
 class NemotronLayerNorm1P(nn.LayerNorm):
     def __call__(self, x):
         weight = self.weight + 1 if "weight" in self else None
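The surrounding context also shows NemotronLayerNorm1P, which stores a zero-centered scale and adds 1 to it at call time. A rough sketch of that idea in plain mlx ops, assuming a standard layer-norm formulation (not the repository's exact implementation, which also handles the no-weight case shown above):

    import mlx.core as mx

    def layer_norm_1p(x, weight, eps=1e-5):
        # Normalize over the last axis, then scale by (weight + 1),
        # so a weight stored as all zeros behaves like gamma = 1.
        mu = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        return (x - mu) * mx.rsqrt(var + eps) * (weight + 1)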
@@ -134,7 +141,7 @@ class MLP(nn.Module):
 
         self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
         self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-        self.act_fn = get_activation(args.hidden_act)()
+        self.act_fn = ReLUSquared()
 
     def __call__(self, x) -> mx.array:
         return self.down_proj(self.act_fn(self.up_proj(x)))
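With this change the Nemotron MLP always applies relu^2 between the up and down projections and no longer consults args.hidden_act. A self-contained sketch of that forward path, using illustrative dimensions rather than Nemotron's real configuration:

    import mlx.core as mx
    import mlx.nn as nn

    class ReLUSquared(nn.Module):
        def __call__(self, input):
            return nn.relu(input).square()

    class MLP(nn.Module):
        def __init__(self, dim=8, hidden_dim=32, mlp_bias=False):
            super().__init__()
            self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
            self.act_fn = ReLUSquared()

        def __call__(self, x) -> mx.array:
            return self.down_proj(self.act_fn(self.up_proj(x)))

    x = mx.random.normal((2, 4, 8))  # (batch, sequence length, dim)
    print(MLP()(x).shape)            # (2, 4, 8): down_proj restores the model dimension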