mlx/python/tests/test_device.py

118 lines
3.7 KiB
Python

# Copyright © 2023 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
# Don't inherit from MLXTestCase to avoid call to setUp
class TestDefaultDevice(unittest.TestCase):
    """Check which device MLX reports as the default.

    Deliberately a plain ``unittest.TestCase`` rather than ``MLXTestCase`` so
    that the latter's ``setUp`` is not invoked before the default is read.
    """

    def test_mlx_default_device(self):
        """The default device is the GPU when available, otherwise the CPU."""
        device = mx.default_device()
        if mx.is_available(mx.gpu):
            self.assertEqual(device, mx.Device(mx.gpu))
            self.assertEqual(str(device), "Device(gpu, 0)")
            # Equality against the bare device type must hold in both
            # operand orders.
            self.assertEqual(device, mx.gpu)
            self.assertEqual(mx.gpu, device)
        else:
            # Compare Device with Device, mirroring the GPU branch, instead
            # of comparing only the device's type against a Device object.
            self.assertEqual(device, mx.Device(mx.cpu))
            # Selecting the GPU as default must be rejected when there is
            # no GPU to select.
            with self.assertRaises(ValueError):
                mx.set_default_device(mx.gpu)
class TestDevice(mlx_tests.MLXTestCase):
    """Tests for changing the default device and running ops on a device."""

    def test_device(self):
        """The default device can be set from an ``mx.Device`` or ``mx.cpu``."""
        device = mx.default_device()
        try:
            cpu = mx.Device(mx.cpu)
            mx.set_default_device(cpu)
            self.assertEqual(mx.default_device(), cpu)
            self.assertEqual(str(cpu), "Device(cpu, 0)")

            mx.set_default_device(mx.cpu)
            self.assertEqual(mx.default_device(), mx.cpu)
            # A Device and the bare device type compare equal in both
            # operand orders.
            self.assertEqual(cpu, mx.cpu)
            self.assertEqual(mx.cpu, cpu)
        finally:
            # Restore the device even when an assertion above fails, so the
            # changed default does not outlive this test body.
            mx.set_default_device(device)

    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
    def test_device_context(self):
        """``mx.stream`` switches the default device only inside the block."""
        default = mx.default_device()
        # Pick whichever device is NOT the current default.
        diff = mx.cpu if default == mx.gpu else mx.gpu
        self.assertNotEqual(default, diff)
        with mx.stream(diff):
            a = mx.add(mx.zeros((2, 2)), mx.ones((2, 2)))
            mx.eval(a)
            self.assertEqual(mx.default_device(), diff)
        # Leaving the context must put the previous default back.
        self.assertEqual(mx.default_device(), default)

    def test_op_on_device(self):
        """An op gives the same result whichever device is passed as stream."""
        x = mx.array(1.0)
        y = mx.array(1.0)

        a = mx.add(x, y, stream=None)
        b = mx.add(x, y, stream=mx.default_device())
        self.assertEqual(a.item(), b.item())

        b = mx.add(x, y, stream=mx.cpu)
        self.assertEqual(a.item(), b.item())

        # Use the same availability check as the rest of this file rather
        # than a Metal-specific one, so the GPU path is exercised whenever a
        # GPU device is reported as available.
        if mx.is_available(mx.gpu):
            b = mx.add(x, y, stream=mx.gpu)
            self.assertEqual(a.item(), b.item())
class TestStream(mlx_tests.MLXTestCase):
    """Tests for default streams, new streams, and running ops on them."""

    def _check_gpu_stream(self, make_stream):
        """Check ``make_stream(mx.gpu)``.

        With a GPU available the returned stream must report ``mx.gpu`` as
        its device; without one the call must raise ``ValueError``.
        """
        if not mx.is_available(mx.gpu):
            with self.assertRaises(ValueError):
                make_stream(mx.gpu)
            return
        gpu_stream = make_stream(mx.gpu)
        self.assertEqual(gpu_stream.device, mx.gpu)

    def test_stream(self):
        """Default and new streams report the device they were created for."""
        existing = mx.default_stream(mx.default_device())
        self.assertEqual(existing.device, mx.default_device())

        fresh = mx.new_stream(mx.default_device())
        self.assertEqual(fresh.device, mx.default_device())
        # A newly created stream is distinct from the default one.
        self.assertNotEqual(existing, fresh)

        self._check_gpu_stream(mx.default_stream)

        cpu_stream = mx.default_stream(mx.cpu)
        self.assertEqual(cpu_stream.device, mx.cpu)

        cpu_stream = mx.new_stream(mx.cpu)
        self.assertEqual(cpu_stream.device, mx.cpu)

        self._check_gpu_stream(mx.new_stream)

    def test_op_on_stream(self):
        """An op gives the same result on every default or new stream."""
        lhs = mx.array(1.0)
        rhs = mx.array(1.0)
        reference = mx.add(
            lhs, rhs, stream=mx.default_stream(mx.default_device())
        )

        # GPU first (only when present), then CPU; for each device try the
        # default stream and then a newly created one.
        devices = [mx.gpu] if mx.is_available(mx.gpu) else []
        devices.append(mx.cpu)
        for target in devices:
            for make_stream in (mx.default_stream, mx.new_stream):
                result = mx.add(lhs, rhs, stream=make_stream(target))
                self.assertEqual(reference.item(), result.item())
if __name__ == "__main__":
    # Executed as a script: hand control to the runner provided by the
    # project's mlx_tests helper module.
    mlx_tests.MLXTestRunner()