mlx/python/tests/test_eval.py
Awni Hannun 1b97b2958b
Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00

38 lines
805 B
Python

# Copyright © 2023 Apple Inc.
import unittest
from functools import partial
import mlx.core as mx
import mlx_tests
class TestEval(mlx_tests.MLXTestCase):
    """Tests for ``mx.eval``.

    Covers three cases:

    * forcing several independent lazy arrays in one call;
    * calling ``mx.eval`` inside a function that is being differentiated;
    * passing a container that mixes arrays with plain Python values.
    """

    def test_eval(self):
        # Build four independent lazy arrays and force them all at once.
        arrays = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrays)

        expected = [[1, 1], [1, 1]]
        for arr in arrays:
            self.assertEqual(arr.tolist(), expected)

    def test_retain_graph(self):
        # Evaluating an intermediate inside a function under mx.grad must
        # still yield the correct gradient. Here f(x) = 2 * (3 * x), so the
        # derivative is 6 everywhere.
        def f(x):
            tripled = 3 * x
            mx.eval(tripled)
            return 2 * tripled

        grad_f = mx.grad(f)
        slope = grad_f(mx.array(1.0))
        self.assertEqual(slope.item(), 6.0)

    def test_eval_mixed(self):
        # mx.eval is handed a list whose other entries are an int and a str,
        # not arrays; the array entry must still be computed.
        total = mx.array(1) + 1 + 1
        mx.eval([total, 0, "hello"])
        self.assertEqual(total.item(), 3)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()