Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 21:01:32 +08:00

fixup! feat: Nemotron
@@ -1,30 +0,0 @@
-"""
-Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
-"""
-
-import mlx.nn as nn
-
-class ReLUSquared(nn.Module):
-    """
-    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
-    """
-
-    def __call__(self, input):
-        return nn.relu(input).square()
-
-ACT2FN = {
-    "gelu": nn.GELU,
-    "mish": nn.Mish,
-    "relu": nn.ReLU,
-    "relu2": ReLUSquared,
-    "relu6": nn.ReLU6,
-    "silu": nn.SiLU,
-    "swish": nn.SiLU,
-    "tanh": nn.Tanh,
-}
-
-def get_activation(activation_string):
-    if activation_string in ACT2FN:
-        return ACT2FN[activation_string]
-    else:
-        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
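For reference, a minimal standalone sketch (not part of the commit) of what the removed "relu2" entry computed: squared ReLU is max(x, 0) ** 2 applied element-wise, which is exactly what the ReLUSquared module re-added further down implements.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5])

# Squared ReLU, as computed by the removed "relu2" entry and by ReLUSquared:
# max(x, 0) ** 2 element-wise.
print(nn.relu(x).square())      # 0, 0, 0, 2.25
print(mx.maximum(x, 0.0) ** 2)  # same values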
@@ -1,13 +1,11 @@
 # Copyright © 2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.activations import get_activation
-
 from .base import BaseModelArgs, KVCache, create_attention_mask
 
 
@@ -52,6 +50,15 @@ class ModelArgs(BaseModelArgs):
             )
 
 
+class ReLUSquared(nn.Module):
+    """
+    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+    """
+
+    def __call__(self, input):
+        return nn.relu(input).square()
+
+
 class NemotronLayerNorm1P(nn.LayerNorm):
     def __call__(self, x):
         weight = self.weight + 1 if "weight" in self else None
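The NemotronLayerNorm1P context above is truncated after the weight line. The idea behind the "1p" variant is that the learned scale is stored as an offset from 1, so the effective gain is 1 + weight and a zero-initialized weight behaves like a unit-gain layer norm. A minimal sketch of that idea, under the assumption that the remainder of the method applies a standard layer norm (the completion below is illustrative, not the committed code):

import mlx.core as mx
import mlx.nn as nn

class LayerNorm1P(nn.LayerNorm):
    # Illustrative stand-in for NemotronLayerNorm1P: the stored weight is an
    # offset from 1, so the normalization uses (1 + weight) as the scale.
    def __call__(self, x):
        weight = self.weight + 1 if "weight" in self else None
        bias = self.bias if "bias" in self else None  # assumed bias handling
        return mx.fast.layer_norm(x, weight, bias, self.eps)

x = mx.random.normal((2, 8))
print(LayerNorm1P(dims=8)(x).shape)  # (2, 8)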
@@ -134,7 +141,7 @@ class MLP(nn.Module):
 
         self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
         self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-        self.act_fn = get_activation(args.hidden_act)()
+        self.act_fn = ReLUSquared()
 
     def __call__(self, x) -> mx.array:
         return self.down_proj(self.act_fn(self.up_proj(x)))
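Putting the hunk in context, a minimal runnable sketch of the MLP this change ends up with: a plain up-projection, squared ReLU, and down-projection, with no gated branch. The dimensions here are made up for illustration; the real values come from ModelArgs.

import mlx.core as mx
import mlx.nn as nn

class ReLUSquared(nn.Module):
    def __call__(self, x):
        return nn.relu(x).square()

class MLP(nn.Module):
    # Hypothetical sizes for illustration; Nemotron reads these from ModelArgs.
    def __init__(self, dim=16, hidden_dim=64, mlp_bias=False):
        super().__init__()
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.act_fn = ReLUSquared()

    def __call__(self, x) -> mx.array:
        # Plain up -> relu^2 -> down; no SwiGLU-style gate.
        return self.down_proj(self.act_fn(self.up_proj(x)))

x = mx.random.normal((1, 4, 16))
print(MLP()(x).shape)  # (1, 4, 16)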