Mirror of https://github.com/ml-explore/mlx.git

* implemented instancenorm
* implemented vector_norm in cpp, added linalg to mlx
* implemented vector_norm python binding
* renamed vector_norm to norm, implemented norm without provided ord
* completed the implementation of the norm
* added tests
* removed unused import in linalg.cpp
* updated python bindings
* added some tests for python bindings
* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy
* added better docs and examples
* refactored mlx.linalg.norm bindings
* reused existing util for implementation of linalg.norm
* more tests
* fixed a bug with no ord and axis provided
* removed unused imports
* some style and API consistency updates to linalg norm
* removed unused includes
* fixed python tests
* fixed a bug with Frobenius norm of a complex-valued matrix
* complex for vector too
* addressed PR review comments
* fixed import order in __init__
* expected values in instancenorm tests are simple lists
* minor return expression style change
* added InstanceNorm to docs
* doc string nits
* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
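As a hedged illustration of the features this commit describes, the sketch below shows how the new norm function and InstanceNorm layer might be used. The argument names (ord, axis, dims) follow the numpy-style conventions the commit message refers to; treat the snippet as an assumption-laden example, not a definitive API reference.

import mlx.core as mx
import mlx.nn as nn

# Vector norms: with no ord given, the 2-norm is computed, mirroring numpy.linalg.norm.
v = mx.array([3.0, 4.0])
print(mx.linalg.norm(v))         # 2-norm -> 5.0
print(mx.linalg.norm(v, ord=1))  # 1-norm -> 7.0

# For a 2-D input the default is the Frobenius norm (also supported for complex matrices).
m = mx.array([[1.0, 2.0], [3.0, 4.0]])
print(mx.linalg.norm(m))

# InstanceNorm normalizes over the spatial dimensions of an (N, ..., C) input.
x = mx.random.normal((4, 16, 8))
norm_layer = nn.InstanceNorm(dims=8)
print(norm_layer(x).shape)  # (4, 16, 8)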
64 lines · 1.2 KiB · Python
# Copyright © 2023 Apple Inc.

from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    SELU,
    Hardswish,
    LeakyReLU,
    LogSigmoid,
    LogSoftmax,
    Mish,
    PReLU,
    ReLU,
    ReLU6,
    SiLU,
    Softmax,
    Softplus,
    Softsign,
    Step,
    Tanh,
    celu,
    elu,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    hardswish,
    leaky_relu,
    log_sigmoid,
    log_softmax,
    mish,
    prelu,
    relu,
    relu6,
    selu,
    silu,
    softmax,
    softplus,
    softsign,
    step,
    tanh,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
from mlx.nn.layers.normalization import (
    BatchNorm,
    GroupNorm,
    InstanceNorm,
    LayerNorm,
    RMSNorm,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    Transformer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
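For context, a minimal sketch of how the layers re-exported by this __init__ are typically consumed, assuming (as in mlx) that the parent mlx.nn package re-exports these names; layer sizes here are purely illustrative.

import mlx.core as mx
import mlx.nn as nn

# Compose a few of the re-exported layers into a small model.
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

x = mx.random.normal((8, 32))
print(model(x).shape)  # (8, 10)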