mlx/python/tests/mlx_tests.py
2023-11-30 11:12:53 -08:00

19 lines
411 B
Python

# Copyright © 2023 Apple Inc.
import os
import unittest
import mlx.core as mx
class MLXTestCase(unittest.TestCase):
    """Base test case that runs each test on a configurable default device.

    If the ``DEVICE`` environment variable is set to a non-empty value, it is
    looked up as an attribute of ``mlx.core`` (e.g. ``DEVICE=cpu`` or
    ``DEVICE=gpu``) and installed as the default device for the test. When
    ``DEVICE`` is unset or empty, the current default device is left as is.
    The default device that was active before the test is restored afterwards.
    """

    def setUp(self):
        """Remember the current default device and apply the ``DEVICE`` override.

        Raises:
            ValueError: if ``DEVICE`` names no attribute of ``mlx.core``.
        """
        self.default = mx.default_device()
        # Register the restore as a cleanup instead of relying on tearDown
        # alone: unittest skips tearDown when setUp raises (e.g. a subclass
        # that calls super().setUp() and then fails), but it always runs
        # cleanups registered before the failure, so the override never leaks
        # into later tests.
        self.addCleanup(mx.set_default_device, self.default)

        # An empty DEVICE (e.g. ``DEVICE= python -m unittest``) means "no
        # override" rather than a lookup of the attribute named "".
        device_name = os.getenv("DEVICE")
        if not device_name:
            return
        try:
            device = getattr(mx, device_name)
        except AttributeError:
            raise ValueError(
                f"DEVICE={device_name!r} does not name a device in mlx.core; "
                "expected e.g. 'cpu' or 'gpu'"
            ) from None
        mx.set_default_device(device)

    def tearDown(self):
        """Restore the default device captured in ``setUp``."""
        # Kept so subclasses that call super().tearDown() keep working;
        # the cleanup registered in setUp sets the same device again.
        mx.set_default_device(self.default)