https://github.com/ml-explore/mlx.git
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
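As a rough usage sketch of what this commit enables (illustrative only, distilled from the tests further down; not part of the diff): a function can be compiled with ``shapeless=True`` so that changing only the input shapes does not trigger recompilation, and compiled functions now accept keyword arguments:

    from functools import partial
    import mlx.core as mx

    @partial(mx.compile, shapeless=True)
    def fun(x, y):
        return x + y

    a = fun(mx.array([1.0, 2.0]), y=mx.array(1.0))       # traced and compiled once
    b = fun(mx.array([1.0, 2.0, 3.0]), y=mx.array(1.0))  # new shape, same compiled graph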
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

import math
from functools import partial
from typing import Any

import mlx.core as mx
@@ -9,13 +10,13 @@ from mlx.nn.layers.base import Module


def _make_activation_module(f):
    def decorator(klass):
        klass.__doc__ = f.__doc__
        klass.__call__ = lambda self, x: f(x)
        klass.__call__ = lambda _, x: f(x)
        return klass

    return decorator
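For orientation (a small sketch, not part of the diff; ``MySigmoid`` is a made-up name): decorating a ``Module`` subclass with ``_make_activation_module(f)`` copies ``f``'s docstring onto the class and makes instances forward their call straight to ``f``, which is how the compiled functional activations below are exposed as modules:

    @_make_activation_module(mx.sigmoid)
    class MySigmoid(Module):   # hypothetical example class
        pass

    m = MySigmoid()
    y = m(mx.array([0.0, 1.0]))   # equivalent to mx.sigmoid(mx.array([0.0, 1.0]))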


@partial(mx.compile, shapeless=True)
def sigmoid(x):
    r"""Applies the element-wise function:
@@ -25,6 +26,7 @@ def sigmoid(x):
    return mx.sigmoid(x)


@partial(mx.compile, shapeless=True)
def relu(x):
    r"""Applies the Rectified Linear Unit.
@@ -33,6 +35,7 @@ def relu(x):
    return mx.maximum(x, 0)


@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
    r"""Applies the Leaky Rectified Linear Unit.
@@ -41,6 +44,7 @@ def leaky_relu(x, negative_slope=0.01):
    return mx.maximum(negative_slope * x, x)


@partial(mx.compile, shapeless=True)
def log_softmax(x, axis=-1):
    r"""Applies the Log Softmax function.
@@ -49,6 +53,7 @@ def log_softmax(x, axis=-1):
    return x - mx.logsumexp(x, axis=axis, keepdims=True)


@partial(mx.compile, shapeless=True)
def elu(x, alpha=1.0):
    r"""Applies the Exponential Linear Unit.
@@ -57,6 +62,7 @@ def elu(x, alpha=1.0):
    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))


@partial(mx.compile, shapeless=True)
def relu6(x):
    r"""Applies the Rectified Linear Unit 6.
@@ -65,6 +71,7 @@ def relu6(x):
    return mx.minimum(mx.maximum(x, 0), 6.0)


@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
    r"""Applies the Softmax function.
@@ -73,6 +80,7 @@ def softmax(x, axis=-1):
    return mx.softmax(x, axis=axis)


@partial(mx.compile, shapeless=True)
def softplus(x):
    r"""Applies the Softplus function.
@@ -81,6 +89,7 @@ def softplus(x):
    return mx.logaddexp(x, 0)


@partial(mx.compile, shapeless=True)
def softsign(x):
    r"""Applies the Softsign function.
@@ -89,6 +98,7 @@ def softsign(x):
    return mx.divide(x, 1 + mx.abs(x))


@partial(mx.compile, shapeless=True)
def softshrink(x, lambd: float = 0.5):
    r"""Applies the Softshrink activation function.
@@ -102,6 +112,7 @@ def softshrink(x, lambd: float = 0.5):
    return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)


@partial(mx.compile, shapeless=True)
def celu(x, alpha=1.0):
    r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -111,6 +122,7 @@ def celu(x, alpha=1.0):
    return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)


@partial(mx.compile, shapeless=True)
def silu(x):
    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
@@ -120,6 +132,7 @@ def silu(x):
    return x * mx.sigmoid(x)


@partial(mx.compile, shapeless=True)
def log_sigmoid(x):
    r"""Applies the Log Sigmoid function.
@@ -128,6 +141,7 @@ def log_sigmoid(x):
    return -softplus(-x)


@partial(mx.compile, shapeless=True)
def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.
@@ -142,6 +156,7 @@ def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


@partial(mx.compile, shapeless=True)
def gelu_approx(x):
    r"""An approximation to Gaussian Error Linear Unit.
@@ -159,6 +174,7 @@ def gelu_approx(x):
    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))


@partial(mx.compile, shapeless=True)
def gelu_fast_approx(x):
    r"""A fast approximation to Gaussian Error Linear Unit.
@@ -192,27 +208,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
    return a * mx.sigmoid(b)


class GLU(Module):
    r"""Applies the gated linear unit function.

    This function splits the ``axis`` dimension of the input into two halves
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
    """

    def __init__(self, axis: int = -1):
        super().__init__()
        self.axis = axis

    def __call__(self, x) -> Any:
        return glu(x=x, axis=self.axis)


@partial(mx.compile, shapeless=True)
def step(x: mx.array, threshold: float = 0.0):
    r"""Applies the Step Activation Function.
@@ -232,6 +228,7 @@ def step(x: mx.array, threshold: float = 0.0):
    return mx.where(x > threshold, 1, 0)


@partial(mx.compile, shapeless=True)
def selu(x):
    r"""Applies the Scaled Exponential Linear Unit.
@@ -248,6 +245,7 @@ def selu(x):
    return elu(x, 1.67326) * 1.0507


@partial(mx.compile, shapeless=True)
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    r"""Applies the element-wise parametric ReLU.
@@ -259,6 +257,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    return mx.maximum(0, x) + alpha * mx.minimum(0, x)


@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
    r"""Applies the Mish function, element-wise.
    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
@@ -272,6 +271,7 @@ def mish(x: mx.array) -> mx.array:
    return x * mx.tanh(softplus(x))


@partial(mx.compile, shapeless=True)
def hardswish(x):
    r"""Applies the hardswish function, element-wise.
@@ -282,6 +282,35 @@ def hardswish(x):
    return x * mx.minimum(max_x_3, 6) / 6


def tanh(x):
    """Applies the hyperbolic tangent function.

    Simply ``mx.tanh(x)``.
    """
    return mx.tanh(x)


class GLU(Module):
    r"""Applies the gated linear unit function.

    This function splits the ``axis`` dimension of the input into two halves
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
    """

    def __init__(self, axis: int = -1):
        super().__init__()
        self.axis = axis

    def __call__(self, x) -> Any:
        return glu(x=x, axis=self.axis)
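A small usage sketch for the relocated ``GLU`` module (illustrative, not part of the diff): the chosen axis is split in half, so the output's size along that axis is half the input's.

    x = mx.random.normal((2, 4))
    out = GLU(axis=-1)(x)   # shape (2, 2): out = a * sigmoid(b), with a, b the two halves of x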


@_make_activation_module(mx.sigmoid)
class Sigmoid(Module):
    r"""Applies the sigmoid function, element-wise.
@@ -500,14 +529,6 @@ class GELU(Module):
        return self._act(x)


def tanh(x):
    """Applies the hyperbolic tangent function.

    Simply ``mx.tanh(x)``.
    """
    return mx.tanh(x)


@_make_activation_module(tanh)
class Tanh(Module):
    r"""Applies the hyperbolic tangent function.
@@ -555,13 +555,19 @@ struct PyCompiledFun {
  size_t fun_id;
  py::object captured_inputs;
  py::object captured_outputs;
  bool shapeless;
  size_t num_outputs{0};

  PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
  PyCompiledFun(
      const py::function& fun,
      py::object inputs,
      py::object outputs,
      bool shapeless)
      : fun(fun),
        fun_id(reinterpret_cast<size_t>(fun.ptr())),
        captured_inputs(inputs),
        captured_outputs(outputs) {}
        captured_outputs(outputs),
        shapeless(shapeless) {}

  PyCompiledFun(const PyCompiledFun&) = delete;
  PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@@ -571,11 +577,15 @@ struct PyCompiledFun {
    other.fun_id = 0;
    captured_inputs = std::move(other.captured_inputs);
    captured_outputs = std::move(other.captured_outputs);
    shapeless = other.shapeless;
    num_outputs = other.num_outputs;
  };

  py::object operator()(const py::args& args) {
    auto compile_fun = [this, &args](const std::vector<array>& a) {
  py::object operator()(const py::args& args, const py::kwargs& kwargs) {
    auto inputs = tree_flatten(args, false);

    auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
                           const std::vector<array>& a) {
      // Put tracers into captured inputs
      std::vector<array> flat_in_captures;
      std::vector<array> trace_captures;
@@ -586,8 +596,10 @@ struct PyCompiledFun {
        tree_fill(captured_inputs, trace_captures);
      }

      auto [outputs, py_outputs] = tree_flatten_with_structure(
          std::move(fun(*tree_unflatten(args, a))), false);
      auto tree_outputs =
          fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
      auto [outputs, py_outputs] =
          tree_flatten_with_structure(std::move(tree_outputs), false);

      tree_cache().insert({fun_id, py_outputs});

@@ -607,7 +619,14 @@ struct PyCompiledFun {
      return outputs;
    };

    auto inputs = tree_flatten(args, false);
    {
      auto flat_kwargs = tree_flatten(kwargs, false);
      inputs.insert(
          inputs.end(),
          std::make_move_iterator(flat_kwargs.begin()),
          std::make_move_iterator(flat_kwargs.end()));
    }

    if (!py::isinstance<py::none>(captured_inputs)) {
      auto flat_in_captures = tree_flatten(captured_inputs, false);
      inputs.insert(
@@ -616,8 +635,39 @@ struct PyCompiledFun {
          std::make_move_iterator(flat_in_captures.end()));
    }

    // Collect the compilation constants
    std::vector<uint64_t> constants;
    auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
      // Consider expanding tuples to their contents including start and end
      // ids
      if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
        auto r = py::hash(o);
        return *reinterpret_cast<uint64_t*>(&r);
      } else if (py::isinstance<py::int_>(o)) {
        auto r = o.cast<int64_t>();
        return *reinterpret_cast<uint64_t*>(&r);
      } else if (py::isinstance<py::float_>(o)) {
        auto r = o.cast<double>();
        return *reinterpret_cast<uint64_t*>(&r);
      } else {
        return std::nullopt;
      }
    };
    for (int i = 0; i < args.size(); i++) {
      if (auto h = value_hash(args[i]); h.has_value()) {
        constants.push_back(*h);
      }
    }
    for (auto& pair : kwargs) {
      if (auto h = value_hash(pair.second); h.has_value()) {
        constants.push_back(*value_hash(pair.first));
        constants.push_back(*h);
      }
    }
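In effect (a Python-level sketch of the behaviour these constants give, mirroring ``test_compile_with_constant`` below; not part of the diff): non-array arguments such as floats, ints, bools, strings, and tuples are hashed into the compilation key, so changing them retraces the function instead of silently reusing a graph that baked in the old value.

    @mx.compile
    def fun(x, y):
        return x + y

    fun(mx.array(1.0), 2.0)  # the constant 2.0 becomes part of the compile key
    fun(mx.array(1.0), 3.0)  # different constant -> the function is traced and compiled again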

    // Compile and call
    auto outputs = detail::compile(compile_fun, fun_id)(inputs);
    auto outputs =
        detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
    if (!py::isinstance<py::none>(captured_outputs)) {
      std::vector<array> captures(
          std::make_move_iterator(outputs.begin() + num_outputs),
@@ -965,12 +1015,14 @@ void init_transforms(py::module_& m) {
      "compile",
      [](const py::function& fun,
         const py::object& inputs,
         const py::object& outputs) {
        return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
         const py::object& outputs,
         bool shapeless) {
        return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
      },
      "fun"_a,
      "inputs"_a = std::nullopt,
      "outputs"_a = std::nullopt,
      "shapeless"_a = false,
      R"pbdoc(
        compile(fun: function) -> function
@@ -990,6 +1042,12 @@ void init_transforms(py::module_& m) {
            :obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
            dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
            Default: ``None``
          shapeless (bool, optional): A function compiled with the ``shapeless``
            option enabled will not be recompiled when the input shape changes. Not all
            functions can be compiled with ``shapeless`` enabled. Attempting to compile
            such functions with shapeless enabled will throw. Note, changing the number
            of dimensions or type of any input will result in a recompilation even with
            ``shapeless`` set to ``True``. Default: ``False``

        Returns:
            function: A compiled function which has the same input arguments
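A hedged sketch of the Python-side signature this binding exposes (parameter names come from the binding above; the state-capture pattern with ``inputs``/``outputs`` follows their existing docstring, not this diff):

    state = [mx.array(0.0)]

    def accumulate(x):
        state[0] = state[0] + x.sum()
        return state[0]

    # inputs/outputs let the compiled function read and update captured state;
    # shapeless additionally avoids retracing when only x's shape changes.
    cfun = mx.compile(accumulate, inputs=state, outputs=state, shapeless=True)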
@@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase):

        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))

    def test_compile_kwargs(self):

        @mx.compile
        def fun(x, y, z):
            return x + y + z

        x = mx.array(1)
        y = mx.array(2)
        z = mx.array(3)
        out = fun(x, y=y, z=z)
        self.assertEqual(out.item(), 6)
    def test_shapeless_compile(self):
        y = 1

        @partial(mx.compile, shapeless=True)
        def fun(x):
            return x + y

        x = mx.array([1, 2])
        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))

        # The function is not recompiled, so the change
        # to y should not be reflected in the output
        y = 2
        x = mx.array([1, 2, 3])
        self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))

        # Type change recompiles
        x = mx.array([1.0, 2.0, 3.0])
        self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))

        # Dim change recompiles
        x = mx.array([[1, 2, 3]])
        self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))

    def test_shapeless_compile_with_broadcasts(self):
        x = mx.ones((2, 2))
        y = mx.array([2, 2])

        def fun(x, y):
            return x * y

        cfun = mx.compile(fun, shapeless=True)
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
        y = mx.array([[3]])
        self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
        self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))

    def test_shapeless_compile_with_reduction(self):
        # Test shapeless compile with a reduction
        z = 1

        @partial(mx.compile, shapeless=True)
        def fun(x, y):
            return x + y.sum(0, keepdims=True) + z

        x = mx.ones((2, 2), mx.int32)
        y = mx.ones((2, 2), mx.int32)
        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
        x = mx.ones((3, 3), mx.int32)
        y = mx.ones((3, 3), mx.int32)
        z = 2
        self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))

        x1 = mx.array([[1, 2], [3, 4], [5, 6]])
        x2 = mx.array([[1, 2]])

        def fun(x):
            return x * x.sum(-1, keepdims=True)

        cfun = mx.compile(fun, shapeless=True)
        mx.eval(cfun(x1))
        self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))

    def test_compile_with_constant(self):

        # Test float
        @partial(mx.compile)
        def fun(x, y):
            return x + y

        z = fun(mx.array(1.0), 1.0)
        self.assertEqual(z.item(), 2.0)

        z = fun(mx.array(1.0), 2.0)
        self.assertEqual(z.item(), 3.0)

        z = fun(mx.array(1.0), y=1.0)
        self.assertEqual(z.item(), 2.0)

        z = fun(mx.array(1.0), y=3.0)
        self.assertEqual(z.item(), 4.0)

        # Test tuple
        @partial(mx.compile)
        def fun(x, y=(1, 2)):
            return x + y[0] + y[1]

        z = fun(mx.array(1))
        self.assertEqual(z.item(), 4)

        z = fun(mx.array(1), (2, 2))
        self.assertEqual(z.item(), 5)

        z = fun(mx.array(1), (2, 1))
        self.assertEqual(z.item(), 4)

        # Test bool
        @partial(mx.compile)
        def fun(x, y):
            if y:
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), True)
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), False)
        self.assertEqual(z.item(), 3)

        # Test string
        @partial(mx.compile)
        def fun(x, y):
            if y == "one":
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), "one")
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), "two")
        self.assertEqual(z.item(), 3)


if __name__ == "__main__":
    unittest.main()