Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
from functools import partial
from typing import Any
import mlx.core as mx
@@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module
def _make_activation_module(f):
def decorator(klass):
klass.__doc__ = f.__doc__
klass.__call__ = lambda self, x: f(x)
klass.__call__ = lambda _, x: f(x)
return klass
return decorator
@partial(mx.compile, shapeless=True)
def sigmoid(x):
r"""Applies the element-wise function:
@@ -25,6 +26,7 @@ def sigmoid(x):
return mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def relu(x):
r"""Applies the Rectified Linear Unit.
@@ -33,6 +35,7 @@ def relu(x):
return mx.maximum(x, 0)
@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
@@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01):
return mx.maximum(negative_slope * x, x)
@partial(mx.compile, shapeless=True)
def log_softmax(x, axis=-1):
r"""Applies the Log Softmax function.
@@ -49,6 +53,7 @@ def log_softmax(x, axis=-1):
return x - mx.logsumexp(x, axis=axis, keepdims=True)
@partial(mx.compile, shapeless=True)
def elu(x, alpha=1.0):
r"""Applies the Exponential Linear Unit.
@@ -57,6 +62,7 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
@@ -65,6 +71,7 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@@ -73,6 +80,7 @@ def softmax(x, axis=-1):
return mx.softmax(x, axis=axis)
@partial(mx.compile, shapeless=True)
def softplus(x):
r"""Applies the Softplus function.
@@ -81,6 +89,7 @@ def softplus(x):
return mx.logaddexp(x, 0)
@partial(mx.compile, shapeless=True)
def softsign(x):
r"""Applies the Softsign function.
@@ -89,6 +98,7 @@ def softsign(x):
return mx.divide(x, 1 + mx.abs(x))
@partial(mx.compile, shapeless=True)
def softshrink(x, lambd: float = 0.5):
r"""Applies the Softshrink activation function.
@@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5):
return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
@partial(mx.compile, shapeless=True)
def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -111,6 +122,7 @@ def celu(x, alpha=1.0):
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
@partial(mx.compile, shapeless=True)
def silu(x):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
@@ -120,6 +132,7 @@ def silu(x):
return x * mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def log_sigmoid(x):
r"""Applies the Log Sigmoid function.
@@ -128,6 +141,7 @@ def log_sigmoid(x):
return -softplus(-x)
@partial(mx.compile, shapeless=True)
def gelu(x):
r"""Applies the Gaussian Error Linear Units function.
@@ -142,6 +156,7 @@ def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@partial(mx.compile, shapeless=True)
def gelu_approx(x):
r"""An approximation to Gaussian Error Linear Unit.
@@ -159,6 +174,7 @@ def gelu_approx(x):
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
@partial(mx.compile, shapeless=True)
def gelu_fast_approx(x):
r"""A fast approximation to Gaussian Error Linear Unit.
@@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
return a * mx.sigmoid(b)
class GLU(Module):
r"""Applies the gated linear unit function.
This function splits the ``axis`` dimension of the input into two halves
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
super().__init__()
self.axis = axis
def __call__(self, x) -> Any:
return glu(x=x, axis=self.axis)
@partial(mx.compile, shapeless=True)
def step(x: mx.array, threshold: float = 0.0):
r"""Applies the Step Activation Function.
@@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0):
return mx.where(x > threshold, 1, 0)
@partial(mx.compile, shapeless=True)
def selu(x):
r"""Applies the Scaled Exponential Linear Unit.
@@ -248,6 +245,7 @@ def selu(x):
return elu(x, 1.67326) * 1.0507
@partial(mx.compile, shapeless=True)
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise parametric ReLU.
@@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
return mx.maximum(0, x) + alpha * mx.minimum(0, x)
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
@@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array:
return x * mx.tanh(softplus(x))
@partial(mx.compile, shapeless=True)
def hardswish(x):
r"""Applies the hardswish function, element-wise.
@@ -282,6 +282,35 @@ def hardswish(x):
return x * mx.minimum(max_x_3, 6) / 6
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
class GLU(Module):
r"""Applies the gated linear unit function.
This function splits the ``axis`` dimension of the input into two halves
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
super().__init__()
self.axis = axis
def __call__(self, x) -> Any:
return glu(x=x, axis=self.axis)
@_make_activation_module(mx.sigmoid)
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
@@ -500,14 +529,6 @@ class GELU(Module):
return self._act(x)
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
@_make_activation_module(tanh)
class Tanh(Module):
r"""Applies the hyperbolic tangent function.