Merge branch 'ml-explore:main' into feature_expand_nn_linear

Asaf Zorea 2023-12-31 08:01:15 +02:00 committed by GitHub
commit 7c57ff98a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 183 additions and 14 deletions


@@ -6,7 +6,8 @@ with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` and `bar` ops.
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added tri, tril, triu and safetensor support


@@ -85,7 +85,7 @@ information on building from source, and running tests.
We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
-to MLX and wish to be acknowledged, please add your name to to the list in your
+to MLX and wish to be acknowledged, please add your name to the list in your
pull request.
## Citing MLX


@@ -735,17 +735,18 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
return arr;
}
-std::vector<int> new_shape(arr.shape());
-new_shape[axis] *= repeats;
+// Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
+std::vector<int> shape(arr.shape());
+shape.insert(shape.begin() + axis + 1, repeats);
+array out = expand_dims(arr, axis + 1, s);
+out = broadcast_to(out, shape, s);
-std::vector<array> repeated_arrays;
-repeated_arrays.reserve(repeats);
+// Reshape back into a contiguous array where S_axis is now S_axis * repeats
+shape.erase(shape.begin() + axis + 1);
+shape[axis] *= repeats;
+out = reshape(out, shape, s);
-for (int i = 0; i < repeats; ++i) {
-  repeated_arrays.push_back(expand_dims(arr, -1, s));
-}
-array repeated = concatenate(repeated_arrays, axis + 1, s);
-return reshape(repeated, new_shape, s);
+return out;
}
array repeat(const array& arr, int repeats, StreamOrDevice s) {
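For readers more comfortable in Python, the broadcast-and-reshape trick used by the new repeat can be sketched as below. repeat_via_broadcast is a hypothetical helper written only for illustration, not part of this diff; it assumes mx.expand_dims, mx.broadcast_to, and mx.reshape from mlx.core.

import mlx.core as mx

def repeat_via_broadcast(arr, repeats, axis):
    # Insert a size-`repeats` axis right after `axis`, broadcast into it,
    # then fold the two axes back together so S_axis becomes S_axis * repeats.
    shape = list(arr.shape)
    shape.insert(axis + 1, repeats)
    out = mx.broadcast_to(mx.expand_dims(arr, axis + 1), shape)
    shape.pop(axis + 1)
    shape[axis] *= repeats
    return mx.reshape(out, shape)

x = mx.array([[1, 2], [3, 4]])
print(repeat_via_broadcast(x, 2, 0))  # [[1, 2], [1, 2], [3, 4], [3, 4]]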


@@ -5,14 +5,18 @@ from mlx.nn.layers.activations import (
ELU,
GELU,
SELU,
Hardswish,
LeakyReLU,
LogSigmoid,
LogSoftmax,
Mish,
PReLU,
ReLU,
ReLU6,
SiLU,
Softmax,
Softplus,
Softsign,
Step,
Tanh,
celu,
@@ -20,15 +24,19 @@ from mlx.nn.layers.activations import (
gelu,
gelu_approx,
gelu_fast_approx,
hardswish,
leaky_relu,
log_sigmoid,
log_softmax,
mish,
prelu,
relu,
relu6,
selu,
silu,
softmax,
softplus,
softsign,
step,
tanh,
)
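The additions above export the new activations at the package level, in both functional and Module form. A short usage sketch; nn.Linear and nn.Sequential appear here only as familiar examples and are not part of this diff.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -1.0, 0.0])

y = nn.hardswish(x)   # functional form
act = nn.Softsign()   # module form
y2 = act(x)

# Module forms drop into a model definition like any other layer
model = nn.Sequential(nn.Linear(3, 8), nn.LogSoftmax())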


@@ -25,7 +25,7 @@ def sigmoid(x):
def relu(x):
"""Applies the Rectified Linear Unit.
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
"""
@@ -33,15 +33,23 @@ def relu(x):
def leaky_relu(x, negative_slope=0.01):
"""Applies the Leaky Rectified Linear Unit.
r"""Applies the Leaky Rectified Linear Unit.
Simply ``mx.maximum(negative_slope * x, x)``.
"""
return mx.maximum(negative_slope * x, x)
def log_softmax(x, axis=-1):
r"""Applies the Log Softmax function.
Applies :math:`x_i - \log \sum_j e^{x_j}` element wise.
"""
return x - mx.logsumexp(x, axis=axis, keepdims=True)
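The new log_softmax subtracts logsumexp rather than computing log(softmax(x)) directly, which avoids overflow for large inputs while giving the same values. A quick sanity check, not part of the diff, assuming mx.log and mx.softmax from mlx.core:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, 2.0, 3.0])
print(nn.log_softmax(x))               # ~[-2.4076, -1.4076, -0.4076]
print(mx.log(mx.softmax(x, axis=-1)))  # same values, less stable for large x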
def elu(x, alpha=1.0):
"""Applies the Exponential Linear Unit.
r"""Applies the Exponential Linear Unit.
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
"""
@@ -56,6 +64,14 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
"""
return mx.softmax(x, axis=axis)
def softplus(x):
r"""Applies the Softplus function.
@@ -64,6 +80,14 @@ def softplus(x):
return mx.logaddexp(x, 0)
def softsign(x):
r"""Applies the Softsign function.
Applies :math:`\frac{x}{1 + |x|}` element wise.
"""
return mx.divide(x, 1 + mx.abs(x))
def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -140,6 +164,11 @@ def gelu_fast_approx(x):
@_make_activation_module
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
"""
pass
@@ -202,13 +231,36 @@ def mish(x: mx.array) -> mx.array:
return x * mx.tanh(softplus(x))
def hardswish(x):
r"""Applies the hardswish function, element-wise.
.. math::
\text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
"""
max_x_3 = mx.maximum(x + 3, 0)
return x * mx.minimum(max_x_3, 6) / 6
@_make_activation_module(mish)
class Mish(Module):
r"""Applies the Mish function, element-wise.
Reference: https://arxiv.org/abs/1908.08681
.. math::
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
"""
pass
@_make_activation_module(relu)
class ReLU(Module):
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
See :func:`relu`, for the functional equivalent.
"""
pass
@@ -249,11 +301,37 @@ class ELU(Module):
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6`, for the functional equivalent.
"""
pass
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.
See :func:`softmax`, for the functional equivalent.
"""
pass
@_make_activation_module(softplus)
class Softplus(Module):
r"""Applies the Softplus function.
See :func:`softplus`, for the functional equivalent.
"""
pass
@_make_activation_module(softsign)
class Softsign(Module):
r"""Applies the Softsign function.
See :func:`softsign`, for the functional equivalent.
"""
pass
@@ -278,15 +356,43 @@ class CELU(Module):
@_make_activation_module(silu)
class SiLU(Module):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
See :func:`silu`, for the functional equivalent.
"""
pass
@_make_activation_module(log_softmax)
class LogSoftmax(Module):
r"""Applies the Log Softmax function.
See :func:`log_softmax`, for the functional equivalent.
"""
pass
@_make_activation_module(log_sigmoid)
class LogSigmoid(Module):
r"""Applies the Log Sigmoid function.
See :func:`log_sigmoid`, for the functional equivalent.
"""
pass
class PReLU(Module):
r"""Applies the element-wise parametric ReLU.
Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
is an array.
See :func:`prelu`, for the functional equivalent.
Args:
num_parameters: number of :math:`a` to learn. Default: 1
init: the initial value of :math:`a`. Default: 0.25
"""
def __init__(self, num_parameters=1, init=0.25):
super().__init__()
self.weight = mx.full([num_parameters], init)
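The learnable slope is stored as self.weight with shape [num_parameters]; the application of prelu itself is defined elsewhere in the file. A brief usage sketch, assuming only the constructor shown above:

import mlx.nn as nn

layer = nn.PReLU()                        # one shared slope, initialized to 0.25
print(layer.weight.shape)                 # [1]

per_channel = nn.PReLU(num_parameters=3)  # one slope per channel
print(per_channel.weight.shape)           # [3]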
@@ -346,6 +452,19 @@ def tanh(x):
@_make_activation_module(tanh)
class Tanh(Module):
r"""Applies the hyperbolic tangent function.
See :func:`tanh`, for the functional equivalent.
"""
pass
@_make_activation_module(hardswish)
class Hardswish(Module):
r"""Applies the hardswish function, element-wise.
See :func:`hardswish`, for the functional equivalent.
"""
pass
@@ -375,4 +494,8 @@ class Step(Module):
@_make_activation_module(selu)
class SELU(Module):
r"""Applies the Scaled Exponential Linear Unit.
See :func:`selu`, for the functional equivalent.
"""
pass


@@ -680,6 +680,15 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [5])
self.assertEqual(y.dtype, mx.float32)
def test_softmax(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softmax(x)
epsilon = 1e-4
expected_y = mx.array([0.6652, 0.0900, 0.2447])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
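The expected values in test_softmax follow directly from the softmax formula and can be reproduced with plain Python:

import math

z = [math.exp(v) for v in (1.0, -1.0, 0.0)]   # [2.7183, 0.3679, 1.0]
print([round(v / sum(z), 4) for v in z])      # [0.6652, 0.09, 0.2447]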
def test_softplus(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softplus(x)
@@ -689,6 +698,15 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_softsign(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softsign(x)
epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_celu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.celu(x)
@@ -704,6 +722,15 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_log_softmax(self):
x = mx.array([1.0, 2.0, 3.0])
y = nn.log_softmax(x)
epsilon = 1e-4
expected_y = mx.array([-2.4076, -1.4076, -0.4076])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_log_sigmoid(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.log_sigmoid(x)
@@ -725,6 +752,15 @@ class TestNN(mlx_tests.MLXTestCase):
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
)
def test_hardswish(self):
x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
y = nn.hardswish(x)
epsilon = 1e-4
expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [5])
self.assertEqual(y.dtype, mx.float32)
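Similarly, the hardswish expectations can be reproduced by hand from x * min(max(x + 3, 0), 6) / 6:

xs = [-3.0, -1.5, 0.0, 1.5, 3.0]
print([x * min(max(x + 3, 0), 6) / 6 for x in xs])
# [-0.0, -0.375, 0.0, 1.125, 3.0]  (-0.0 == 0.0, matching expected_y)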
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
rope = nn.RoPE(4, **kwargs)