Compile primitive (#571)

* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
Awni Hannun
2024-02-05 06:51:22 -08:00
committed by GitHub
parent 31fea3758e
commit d75ae52ecd
15 changed files with 1088 additions and 75 deletions

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "mlx/array.h"
#include "mlx/compile.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"

View File

@@ -190,6 +190,117 @@ class TestCompile(mlx_tests.MLXTestCase):
n_enable_compiled = count_prims(cfun(x))
self.assertEqual(n_compiled, n_enable_compiled)
def test_compile_two_input_grad(self):
def loss(w, x):
y = x * w
return (y * mx.exp(y)).sum()
x = mx.array([1.0, 0.5, 2.0, -0.5])
w = mx.array([-1.0, 0.3, 1.0, -0.9])
expected_grad = mx.grad(loss)(w, x)
compiled_grad = mx.compile(mx.grad(loss))(w, x)
self.assertTrue(mx.allclose(expected_grad, compiled_grad))
def test_vmap_compiled(self):
def simple_unary(x):
return -mx.exp(x)
x = mx.array([[1.0, 2.0], [2.0, 3.0]])
expected_out = mx.vmap(simple_unary)(x)
out = mx.vmap(mx.compile(simple_unary))(x)
self.assertTrue(mx.allclose(expected_out, out))
def simple_binary(x, y):
return mx.abs(mx.exp(x + y) + y)
x = mx.array([[1.0, -3.0], [0.5, -0.5]])
y = mx.array([[2.0, -1.0], [0.25, -0.25]])
expected_out = mx.vmap(simple_binary)(x, y)
out = mx.vmap(mx.compile(simple_binary))(x, y)
self.assertTrue(mx.allclose(expected_out, out))
expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y)
out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y)
self.assertTrue(mx.allclose(expected_out, out))
y = mx.array([0.25, -0.25])
expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y)
out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y)
self.assertTrue(mx.allclose(expected_out, out))
def simple_unary_outer(x):
x = mx.abs(x)
@mx.compile
def simple_unary_inner(z):
return -mx.exp(x)
return simple_unary_inner(x)
expected_out = -mx.exp(mx.abs(x))
out = mx.vmap(simple_unary_outer)(x)
self.assertTrue(mx.allclose(expected_out, out))
def test_vjp_vjp_compiled(self):
def simple_unary(x):
return -mx.exp(x)
x = mx.array([[1.0, 2.0], [2.0, 3.0]])
y = mx.array([[1.0, 1.0], [1.0, 1.0]])
expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,))
out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,))
self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,))
out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,))
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
def simple_binary(x, y):
return mx.abs(mx.exp(x + y) + y)
x = mx.array([[1.0, -3.0], [0.5, -0.5]])
y = mx.array([[2.0, -1.0], [0.25, -0.25]])
cotans = mx.ones_like(x)
expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,))
out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1]))
tans = (mx.ones_like(x), mx.ones_like(y))
expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans)
out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans)
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
def test_transform_over_eval_compiled(self):
def outer(x):
y = mx.exp(mx.abs(x))
mx.eval(y)
return y.sum()
x = mx.array([2.0, -1.0, 0.5])
dfdx = mx.grad(outer)(x)
@mx.compile
def simple_unary(x):
return mx.exp(mx.abs(x))
def outer(x):
y = simple_unary(x)
mx.eval(y)
return y.sum()
cdfdx = mx.grad(outer)(x)
self.assertTrue(mx.allclose(dfdx, cdfdx))
if __name__ == "__main__":
unittest.main()