# Mirror of https://github.com/ml-explore/mlx.git
# (synced 2025-06-24 01:17:26 +08:00)
import os
|
|
import unittest
|
|
|
|
import mlx.core as mx
|
|
|
|
|
|
class MLXTestCase(unittest.TestCase):
    """Base test case that controls the MLX default device for each test.

    ``setUp`` records the current default device. When the ``DEVICE``
    environment variable is set to a non-empty name, ``setUp`` looks that
    name up on ``mlx.core`` (e.g. ``cpu`` or ``gpu``) and makes it the
    default device. ``tearDown`` restores the device recorded in ``setUp``,
    so one test's device choice does not leak into the next.
    """

    def setUp(self):
        """Record the current default device and apply ``$DEVICE`` if set.

        Raises:
            ValueError: if ``DEVICE`` names something ``mlx.core`` does not
                expose.
        """
        # Cooperate with any other TestCase mixins in the MRO.
        super().setUp()
        # Saved first so tearDown can always restore it, even if the
        # DEVICE lookup below fails.
        self.default = mx.default_device()

        name = os.getenv("DEVICE")
        # Unset or empty DEVICE means "keep the current default device".
        if not name:
            return

        try:
            device = getattr(mx, name)
        except AttributeError:
            # Point at the environment variable rather than surfacing a bare
            # "module 'mlx.core' has no attribute ..." error.
            raise ValueError(
                f"DEVICE={name!r} does not name a device exposed by "
                "mlx.core (expected e.g. 'cpu' or 'gpu')"
            ) from None
        mx.set_default_device(device)

    def tearDown(self):
        """Restore the default device that was active before ``setUp``."""
        mx.set_default_device(self.default)
        # Unwind in reverse order of setUp.
        super().tearDown()