Shapeless compilation for some graphs (#687)

* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-19 21:43:54 -08:00
committed by GitHub
parent d0fda82595
commit 5798256fcf
14 changed files with 645 additions and 113 deletions

View File

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
from functools import partial
from typing import Any
import mlx.core as mx
@@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module
def _make_activation_module(f):
def decorator(klass):
klass.__doc__ = f.__doc__
klass.__call__ = lambda self, x: f(x)
klass.__call__ = lambda _, x: f(x)
return klass
return decorator
@partial(mx.compile, shapeless=True)
def sigmoid(x):
r"""Applies the element-wise function:
@@ -25,6 +26,7 @@ def sigmoid(x):
return mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def relu(x):
r"""Applies the Rectified Linear Unit.
@@ -33,6 +35,7 @@ def relu(x):
return mx.maximum(x, 0)
@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
@@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01):
return mx.maximum(negative_slope * x, x)
@partial(mx.compile, shapeless=True)
def log_softmax(x, axis=-1):
r"""Applies the Log Softmax function.
@@ -49,6 +53,7 @@ def log_softmax(x, axis=-1):
return x - mx.logsumexp(x, axis=axis, keepdims=True)
@partial(mx.compile, shapeless=True)
def elu(x, alpha=1.0):
r"""Applies the Exponential Linear Unit.
@@ -57,6 +62,7 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
@@ -65,6 +71,7 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@@ -73,6 +80,7 @@ def softmax(x, axis=-1):
return mx.softmax(x, axis=axis)
@partial(mx.compile, shapeless=True)
def softplus(x):
r"""Applies the Softplus function.
@@ -81,6 +89,7 @@ def softplus(x):
return mx.logaddexp(x, 0)
@partial(mx.compile, shapeless=True)
def softsign(x):
r"""Applies the Softsign function.
@@ -89,6 +98,7 @@ def softsign(x):
return mx.divide(x, 1 + mx.abs(x))
@partial(mx.compile, shapeless=True)
def softshrink(x, lambd: float = 0.5):
r"""Applies the Softshrink activation function.
@@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5):
return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
@partial(mx.compile, shapeless=True)
def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -111,6 +122,7 @@ def celu(x, alpha=1.0):
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
@partial(mx.compile, shapeless=True)
def silu(x):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
@@ -120,6 +132,7 @@ def silu(x):
return x * mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def log_sigmoid(x):
r"""Applies the Log Sigmoid function.
@@ -128,6 +141,7 @@ def log_sigmoid(x):
return -softplus(-x)
@partial(mx.compile, shapeless=True)
def gelu(x):
r"""Applies the Gaussian Error Linear Units function.
@@ -142,6 +156,7 @@ def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@partial(mx.compile, shapeless=True)
def gelu_approx(x):
r"""An approximation to Gaussian Error Linear Unit.
@@ -159,6 +174,7 @@ def gelu_approx(x):
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
@partial(mx.compile, shapeless=True)
def gelu_fast_approx(x):
r"""A fast approximation to Gaussian Error Linear Unit.
@@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
return a * mx.sigmoid(b)
class GLU(Module):
r"""Applies the gated linear unit function.
This function splits the ``axis`` dimension of the input into two halves
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
super().__init__()
self.axis = axis
def __call__(self, x) -> Any:
return glu(x=x, axis=self.axis)
@partial(mx.compile, shapeless=True)
def step(x: mx.array, threshold: float = 0.0):
r"""Applies the Step Activation Function.
@@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0):
return mx.where(x > threshold, 1, 0)
@partial(mx.compile, shapeless=True)
def selu(x):
r"""Applies the Scaled Exponential Linear Unit.
@@ -248,6 +245,7 @@ def selu(x):
return elu(x, 1.67326) * 1.0507
@partial(mx.compile, shapeless=True)
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise parametric ReLU.
@@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
return mx.maximum(0, x) + alpha * mx.minimum(0, x)
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
@@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array:
return x * mx.tanh(softplus(x))
@partial(mx.compile, shapeless=True)
def hardswish(x):
r"""Applies the hardswish function, element-wise.
@@ -282,6 +282,35 @@ def hardswish(x):
return x * mx.minimum(max_x_3, 6) / 6
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
class GLU(Module):
r"""Applies the gated linear unit function.
This function splits the ``axis`` dimension of the input into two halves
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
super().__init__()
self.axis = axis
def __call__(self, x) -> Any:
return glu(x=x, axis=self.axis)
@_make_activation_module(mx.sigmoid)
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
@@ -500,14 +529,6 @@ class GELU(Module):
return self._act(x)
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
@_make_activation_module(tanh)
class Tanh(Module):
r"""Applies the hyperbolic tangent function.

View File

@@ -555,13 +555,19 @@ struct PyCompiledFun {
size_t fun_id;
py::object captured_inputs;
py::object captured_outputs;
bool shapeless;
size_t num_outputs{0};
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
PyCompiledFun(
const py::function& fun,
py::object inputs,
py::object outputs,
bool shapeless)
: fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
captured_inputs(inputs),
captured_outputs(outputs) {}
captured_outputs(outputs),
shapeless(shapeless) {}
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@@ -571,11 +577,15 @@ struct PyCompiledFun {
other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);
shapeless = other.shapeless;
num_outputs = other.num_outputs;
};
py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto inputs = tree_flatten(args, false);
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
const std::vector<array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
@@ -586,8 +596,10 @@ struct PyCompiledFun {
tree_fill(captured_inputs, trace_captures);
}
auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(fun(*tree_unflatten(args, a))), false);
auto tree_outputs =
fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
auto [outputs, py_outputs] =
tree_flatten_with_structure(std::move(tree_outputs), false);
tree_cache().insert({fun_id, py_outputs});
@@ -607,7 +619,14 @@ struct PyCompiledFun {
return outputs;
};
auto inputs = tree_flatten(args, false);
{
auto flat_kwargs = tree_flatten(kwargs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_kwargs.begin()),
std::make_move_iterator(flat_kwargs.end()));
}
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
@@ -616,8 +635,39 @@ struct PyCompiledFun {
std::make_move_iterator(flat_in_captures.end()));
}
// Collect the compilation constants
std::vector<uint64_t> constants;
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
// Consider expanding tuples to their contents including start and end
// ids
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
auto r = py::hash(o);
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::int_>(o)) {
auto r = o.cast<int64_t>();
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::float_>(o)) {
auto r = o.cast<double>();
return *reinterpret_cast<uint64_t*>(&r);
} else {
return std::nullopt;
}
};
for (int i = 0; i < args.size(); i++) {
if (auto h = value_hash(args[i]); h.has_value()) {
constants.push_back(*h);
}
}
for (auto& pair : kwargs) {
if (auto h = value_hash(pair.second); h.has_value()) {
constants.push_back(*value_hash(pair.first));
constants.push_back(*h);
}
}
// Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) {
std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
@@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) {
"compile",
[](const py::function& fun,
const py::object& inputs,
const py::object& outputs) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
const py::object& outputs,
bool shapeless) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
},
"fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
"shapeless"_a = false,
R"pbdoc(
compile(fun: function) -> function
@@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) {
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
Default: ``None``
shapeless (bool, optional): A function compiled with the ``shapeless``
option enabled will not be recompiled when the input shape changes. Not all
functions can be compiled with ``shapeless`` enabled. Attempting to compile
such functions with shapeless enabled will throw. Note, changing the number
of dimensions or type of any input will result in a recompilation even with
``shapeless`` set to ``True``. Default: ``False``
Returns:
function: A compiled function which has the same input arguments

View File

@@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
def test_compile_kwargs(self):
@mx.compile
def fun(x, y, z):
return x + y + z
x = mx.array(1)
y = mx.array(2)
z = mx.array(3)
out = fun(x, y=y, z=z)
self.assertEqual(out.item(), 6)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
fun(x, y=y, z=z)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
# Dim change recompiles
x = mx.array([[1, 2, 3]])
self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))
def test_shapeless_compile_with_broadcasts(self):
x = mx.ones((2, 2))
y = mx.array([2, 2])
def fun(x, y):
return x * y
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
y = mx.array([[3]])
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
def test_shapeless_compile_with_reduction(self):
# Test shapeless compile with a reduction
z = 1
@partial(mx.compile, shapeless=True)
def fun(x, y):
return x + y.sum(0, keepdims=True) + z
x = mx.ones((2, 2), mx.int32)
y = mx.ones((2, 2), mx.int32)
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
x = mx.ones((3, 3), mx.int32)
y = mx.ones((3, 3), mx.int32)
z = 2
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))
x1 = mx.array([[1, 2], [3, 4], [5, 6]])
x2 = mx.array([[1, 2]])
def fun(x):
return x * x.sum(-1, keepdims=True)
cfun = mx.compile(fun, shapeless=True)
mx.eval(cfun(x1))
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
def fun(x, y):
return x + y
z = fun(mx.array(1.0), 1.0)
self.assertEqual(z.item(), 2.0)
z = fun(mx.array(1.0), 2.0)
self.assertEqual(z.item(), 3.0)
z = fun(mx.array(1.0), y=1.0)
self.assertEqual(z.item(), 2.0)
z = fun(mx.array(1.0), y=3.0)
self.assertEqual(z.item(), 4.0)
# Test tuple
@partial(mx.compile)
def fun(x, y=(1, 2)):
return x + y[0] + y[1]
z = fun(mx.array(1))
self.assertEqual(z.item(), 4)
z = fun(mx.array(1), (2, 2))
self.assertEqual(z.item(), 5)
z = fun(mx.array(1), (2, 1))
self.assertEqual(z.item(), 4)
# Test bool
@partial(mx.compile)
def fun(x, y):
if y:
return x + 1
else:
return x + 2
z = fun(mx.array(1), True)
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), False)
self.assertEqual(z.item(), 3)
# Test string
@partial(mx.compile)
def fun(x, y):
if y == "one":
return x + 1
else:
return x + 2
z = fun(mx.array(1), "one")
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
if __name__ == "__main__":
unittest.main()