mlx-examples/llms/mlx_lm/activations.py
llvvuu 9c86658aff feat: Nemotron
https://huggingface.co/nvidia/Minitron-4B-Base

This is basically Llama with partial RoPE and LayerNorm instead of
RMSNorm. Also they add 1 to the LayerNorm weight for some reason.
2024-08-21 18:46:35 +09:00
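
A minimal sketch of the "+1" LayerNorm variant mentioned in the description above, assuming MLX's nn.LayerNorm and mx.fast.layer_norm; the class name LayerNorm1P and this exact formulation are illustrative, not necessarily what the Nemotron port uses:

import mlx.core as mx
import mlx.nn as nn


class LayerNorm1P(nn.LayerNorm):
    """Hypothetical sketch: LayerNorm whose learned scale is offset by 1,
    so a zero-initialized weight behaves as an identity scale."""

    def __call__(self, x):
        # Same normalization as nn.LayerNorm, but applied with (weight + 1) as the scale.
        return mx.fast.layer_norm(x, self.weight + 1, self.bias, self.eps)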

31 lines · 749 B · Python

"""
Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
"""
import mlx.nn as nn
class ReLUSquared(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def __call__(self, input):
return nn.relu(input).square()
ACT2FN = {
"gelu": nn.GELU,
"mish": nn.Mish,
"relu": nn.ReLU,
"relu2": ReLUSquared,
"relu6": nn.ReLU6,
"silu": nn.SiLU,
"swish": nn.SiLU,
"tanh": nn.Tanh,
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
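
Note that ACT2FN maps activation names to module classes rather than instances, so callers instantiate the looked-up class before applying it. A small usage sketch, assuming the file is importable as mlx_lm.activations:

import mlx.core as mx

from mlx_lm.activations import get_activation

act = get_activation("relu2")()  # look up the class, then instantiate it
x = mx.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(act(x))  # negatives clamp to 0, positives are squared: [0, 0, 0, 1, 4]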