mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-24 12:18:20 +08:00

* Simple kernel generation * Remove the generate kernel from graph_utils * fix multi-output with compile * fuse with stopgrad * v1 input, output capture in compile * cleanup tree update with visitor update * nit * remove todo * state for model, optional explicit init and more pure optimizer steps * move learning rate to state * add lr to opt state, some fixes in capture * fix optim * update tuple of containers as well * fix stream for compiled output * rng state for compile * nit * updates and comments --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
38 lines
805 B
Python
38 lines
805 B
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import unittest
|
|
from functools import partial
|
|
|
|
import mlx.core as mx
|
|
import mlx_tests
|
|
|
|
|
|
class TestEval(mlx_tests.MLXTestCase):
    """Tests for ``mx.eval``, which forces evaluation of lazily built arrays."""

    def test_eval(self):
        """Evaluating several arrays in one call materializes each of them."""
        arrs = [mx.ones((2, 2)) for _ in range(4)]
        mx.eval(*arrs)
        for x in arrs:
            self.assertEqual(x.tolist(), [[1, 1], [1, 1]])

    def test_retain_graph(self):
        """Calling ``mx.eval`` inside a function being differentiated must not
        cut the gradient path through the evaluated array."""

        def fun(x):
            y = 3 * x
            # Force evaluation mid-function. d/dx of 2 * (3 * x) is 6 only if
            # y is still connected to x after being evaluated.
            mx.eval(y)
            return 2 * y

        dfun_dx = mx.grad(fun)
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
        """``mx.eval`` accepts a container that mixes arrays with non-array
        values: the array is evaluated and the other entries pass through."""
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)
        # The non-array leaves must be accepted without error and the
        # container must be left exactly as it was passed in.
        self.assertEqual(len(state), 3)
        self.assertIs(state[0], x)
        self.assertEqual(state[1], 0)
        self.assertEqual(state[2], "hello")
|
|
|
|
|
# Run every test case in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|