Merge branch 'ml-explore:main' into feature_expand_nn_linear

Asaf Zorea authored on 2023-12-31 08:01:15 +02:00 (committed by GitHub)
commit 7c57ff98a2
6 changed files with 183 additions and 14 deletions


@@ -6,7 +6,8 @@ with a short description of your contribution(s) below. For example:
 - Jane Smith: Added the `foo` and `bar` ops.
 
 MLX was developed with contributions from the following individuals:
 
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support


@@ -85,7 +85,7 @@ information on building from source, and running tests.
 
 We are grateful for all of [our
 contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
-to MLX and wish to be acknowledged, please add your name to to the list in your
+to MLX and wish to be acknowledged, please add your name to the list in your
 pull request.
 
 ## Citing MLX


@@ -735,17 +735,18 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
     return arr;
   }
 
-  std::vector<int> new_shape(arr.shape());
-  new_shape[axis] *= repeats;
-
-  std::vector<array> repeated_arrays;
-  repeated_arrays.reserve(repeats);
-
-  for (int i = 0; i < repeats; ++i) {
-    repeated_arrays.push_back(expand_dims(arr, -1, s));
-  }
-  array repeated = concatenate(repeated_arrays, axis + 1, s);
-  return reshape(repeated, new_shape, s);
+  // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
+  std::vector<int> shape(arr.shape());
+  shape.insert(shape.begin() + axis + 1, repeats);
+  array out = expand_dims(arr, axis + 1, s);
+  out = broadcast_to(out, shape, s);
+
+  // Reshape back into a contiguous array where S_axis is now S_axis * repeats
+  shape.erase(shape.begin() + axis + 1);
+  shape[axis] *= repeats;
+  out = reshape(out, shape, s);
+  return out;
 }
 
 array repeat(const array& arr, int repeats, StreamOrDevice s) {
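For readers less familiar with the C++ ops API, here is a minimal Python sketch of the same broadcast-based trick the new implementation uses: insert a size-`repeats` axis right after `axis`, broadcast into it, then fold it back into `axis`. The helper name `repeat_via_broadcast` is made up for illustration and the snippet assumes `mlx.core` is available; MLX exposes the real operation as `mx.repeat`.

```python
import mlx.core as mx


def repeat_via_broadcast(arr, repeats, axis):
    # Insert a new axis of size `repeats` immediately after `axis` and
    # broadcast the data into it.
    shape = list(arr.shape)
    shape.insert(axis + 1, repeats)
    out = mx.expand_dims(arr, axis + 1)
    out = mx.broadcast_to(out, shape)
    # Fold the repeated axis back in: S_axis becomes S_axis * repeats.
    shape.pop(axis + 1)
    shape[axis] *= repeats
    return mx.reshape(out, shape)


x = mx.array([[1, 2], [3, 4]])
print(repeat_via_broadcast(x, 2, axis=1))  # [[1, 1, 2, 2], [3, 3, 4, 4]]
```

The broadcast-and-reshape route expresses the repeat with a couple of shape ops instead of building and concatenating `repeats` intermediate arrays, which is the point of the refactor shown in the hunk above.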


@@ -5,14 +5,18 @@ from mlx.nn.layers.activations import (
     ELU,
     GELU,
     SELU,
+    Hardswish,
     LeakyReLU,
     LogSigmoid,
+    LogSoftmax,
     Mish,
     PReLU,
     ReLU,
     ReLU6,
     SiLU,
+    Softmax,
     Softplus,
+    Softsign,
     Step,
     Tanh,
     celu,
@@ -20,15 +24,19 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    hardswish,
     leaky_relu,
     log_sigmoid,
+    log_softmax,
     mish,
     prelu,
     relu,
     relu6,
     selu,
     silu,
+    softmax,
     softplus,
+    softsign,
     step,
     tanh,
 )
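A brief usage sketch of what these re-exports enable: each new activation should be reachable from `mlx.nn` both as a function and as a `Module` class (the module form is convenient inside container modules). This assumes an MLX build that includes this commit; the shapes and `Sequential` wiring below are illustrative only.

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, 0.0, 2.0])

# Functional forms defined in activations.py
print(nn.softmax(x))
print(nn.hardswish(x))

# Module forms re-exported here, usable as layers
model = nn.Sequential(nn.Linear(3, 3), nn.Hardswish())
print(model(x))
print(nn.Softmax()(x))
```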


@@ -25,7 +25,7 @@ def sigmoid(x):
 
 def relu(x):
-    """Applies the Rectified Linear Unit.
+    r"""Applies the Rectified Linear Unit.
 
     Simply ``mx.maximum(x, 0)``.
     """
@@ -33,15 +33,23 @@ def relu(x):
 
 def leaky_relu(x, negative_slope=0.01):
-    """Applies the Leaky Rectified Linear Unit.
+    r"""Applies the Leaky Rectified Linear Unit.
 
     Simply ``mx.maximum(negative_slope * x, x)``.
     """
     return mx.maximum(negative_slope * x, x)
 
 
+def log_softmax(x, axis=-1):
+    r"""Applies the Log Softmax function.
+
+    Applies :math:`x - \log \sum_i e^{x_i}` element wise.
+    """
+    return x - mx.logsumexp(x, axis=axis, keepdims=True)
+
+
 def elu(x, alpha=1.0):
-    """Applies the Exponential Linear Unit.
+    r"""Applies the Exponential Linear Unit.
 
     Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
     """
@@ -56,6 +64,14 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+def softmax(x, axis=-1):
+    r"""Applies the Softmax function.
+
+    Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
+    """
+    return mx.softmax(x, axis=axis)
+
+
 def softplus(x):
     r"""Applies the Softplus function.
@@ -64,6 +80,14 @@ def softplus(x):
     return mx.logaddexp(x, 0)
 
 
+def softsign(x):
+    r"""Applies the Softsign function.
+
+    Applies :math:`\frac{x}{1 + |x|}` element wise.
+    """
+    return mx.divide(x, 1 + mx.abs(x))
+
+
 def celu(x, alpha=1.0):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -140,6 +164,11 @@ def gelu_fast_approx(x):
 
 @_make_activation_module
 class Sigmoid(Module):
+    r"""Applies the sigmoid function, element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+    """
     pass
@@ -202,13 +231,36 @@ def mish(x: mx.array) -> mx.array:
     return x * mx.tanh(softplus(x))
 
 
+def hardswish(x):
+    r"""Applies the hardswish function, element-wise.
+
+    .. math::
+        \text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
+    """
+    max_x_3 = mx.maximum(x + 3, 0)
+    return x * mx.minimum(max_x_3, 6) / 6
+
+
 @_make_activation_module(mish)
 class Mish(Module):
+    r"""Applies the Mish function, element-wise.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+    """
     pass
 
 
 @_make_activation_module(relu)
 class ReLU(Module):
+    r"""Applies the Rectified Linear Unit.
+
+    Simply ``mx.maximum(x, 0)``.
+
+    See :func:`relu`, for the functional equivalent.
+    """
     pass
@@ -249,11 +301,37 @@ class ELU(Module):
 
 @_make_activation_module(relu6)
 class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softmax)
+class Softmax(Module):
+    r"""Applies the Softmax function.
+
+    See :func:`softmax`, for the functional equivalent.
+    """
     pass
 
 
 @_make_activation_module(softplus)
 class Softplus(Module):
+    r"""Applies the Softplus function.
+
+    See :func:`softplus`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softsign)
+class Softsign(Module):
+    r"""Applies the Softsign function.
+
+    See :func:`softsign`, for the functional equivalent.
+    """
     pass
@@ -278,15 +356,43 @@ class CELU(Module):
 
 @_make_activation_module(silu)
 class SiLU(Module):
+    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
+
+    See :func:`silu`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(log_softmax)
+class LogSoftmax(Module):
+    r"""Applies the Log Softmax function.
+
+    See :func:`log_softmax`, for the functional equivalent.
+    """
     pass
 
 
 @_make_activation_module(log_sigmoid)
 class LogSigmoid(Module):
+    r"""Applies the Log Sigmoid function.
+
+    See :func:`log_sigmoid`, for the functional equivalent.
+    """
     pass
 
 
 class PReLU(Module):
+    r"""Applies the element-wise parametric ReLU.
+
+    Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
+    is an array.
+
+    See :func:`prelu`, for the functional equivalent.
+
+    Args:
+        num_parameters: number of :math:`a` to learn. Default: 1
+        init: the initial value of :math:`a`. Default: 0.25
+    """
+
     def __init__(self, num_parameters=1, init=0.25):
         super().__init__()
         self.weight = mx.full([num_parameters], init)
@@ -346,6 +452,19 @@ def tanh(x):
 
 @_make_activation_module(tanh)
 class Tanh(Module):
+    r"""Applies the hyperbolic tangent function.
+
+    See :func:`tanh`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(hardswish)
+class Hardswish(Module):
+    r"""Applies the hardswish function, element-wise.
+
+    See :func:`hardswish`, for the functional equivalent.
+    """
     pass
@@ -375,4 +494,8 @@ class Step(Module):
 
 @_make_activation_module(selu)
 class SELU(Module):
+    r"""Applies the Scaled Exponential Linear Unit.
+
+    See :func:`selu`, for the functional equivalent.
+    """
     pass
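One detail worth calling out from the new `log_softmax` above: it is computed as `x - logsumexp(x)` rather than `log(softmax(x))`, which is the numerically stable formulation. A small illustrative check that both routes agree on a well-scaled input, assuming `mlx.core` is installed:

```python
import mlx.core as mx

x = mx.array([1.0, 2.0, 3.0])

naive = mx.log(mx.softmax(x, axis=-1))                 # log applied after softmax
stable = x - mx.logsumexp(x, axis=-1, keepdims=True)   # what log_softmax does

print(naive)   # about [-2.4076, -1.4076, -0.4076]
print(stable)  # same values, but also safe for large-magnitude inputs
```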


@@ -680,6 +680,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softmax(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softmax(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.6652, 0.0900, 0.2447])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_softplus(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.softplus(x)
@@ -689,6 +698,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softsign(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softsign(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.5, -0.5, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_celu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.celu(x)
@@ -704,6 +722,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_log_softmax(self):
+        x = mx.array([1.0, 2.0, 3.0])
+        y = nn.log_softmax(x)
+        epsilon = 1e-4
+        expected_y = mx.array([-2.4076, -1.4076, -0.4076])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_log_sigmoid(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.log_sigmoid(x)
@@ -725,6 +752,15 @@ class TestNN(mlx_tests.MLXTestCase):
             mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
         )
 
+    def test_hardswish(self):
+        x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
+        y = nn.hardswish(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [5])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
             rope = nn.RoPE(4, **kwargs)
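The expected constants in the new tests (for example `test_softmax`'s `[0.6652, 0.0900, 0.2447]`) can be re-derived by hand. A quick standard-library check, included purely for illustration:

```python
import math

x = [1.0, -1.0, 0.0]
exps = [math.exp(v) for v in x]
print([round(e / sum(exps), 4) for e in exps])   # softmax  -> [0.6652, 0.09, 0.2447]
print([round(v / (1 + abs(v)), 4) for v in x])   # softsign -> [0.5, -0.5, 0.0]

z = [1.0, 2.0, 3.0]
lse = math.log(sum(math.exp(v) for v in z))
print([round(v - lse, 4) for v in z])            # log_softmax -> [-2.4076, -1.4076, -0.4076]
```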