mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00

* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
36 lines
943 B
Python
36 lines
943 B
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import os
|
|
import unittest
|
|
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
|
|
|
|
class MLXTestCase(unittest.TestCase):
    """Base ``unittest.TestCase`` for MLX tests.

    Optionally overrides the default MLX device for each test from the
    ``DEVICE`` environment variable, restores the previous default device
    after each test, and provides an array-comparison assertion helper.
    """

    def setUp(self):
        """Record the current default device and optionally override it.

        If the ``DEVICE`` environment variable is set to a non-empty value,
        that name is looked up as an attribute of ``mlx.core`` (presumably
        ``cpu`` or ``gpu``) and made the default device for this test. An
        unknown name raises ``AttributeError``.
        """
        super().setUp()
        # Saved so tearDown can restore it no matter what the test changes.
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        # Treat an empty ``DEVICE=`` like an unset variable instead of
        # failing on ``getattr(mx, "")``.
        if device:
            mx.set_default_device(getattr(mx, device))

    def tearDown(self):
        """Restore the default device recorded in ``setUp``."""
        mx.set_default_device(self.default)
        super().tearDown()

    def assertEqualArray(
        self,
        mx_res: mx.array,
        expected: mx.array,
        atol=1e-2,
        rtol=1e-2,
        **kwargs,
    ):
        """Assert that ``mx_res`` matches ``expected`` in shape, dtype and values.

        Args:
            mx_res: The array produced by the code under test.
            expected: The reference array.
            atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``.
            rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``.
            **kwargs: Accepted for call-site compatibility; currently ignored.

        Raises:
            AssertionError: (``self.failureException``) on a shape, dtype, or
                value mismatch.
        """
        # Use self.fail rather than bare ``assert`` so these checks still run
        # when Python is started with -O (which strips assert statements).
        if tuple(mx_res.shape) != tuple(expected.shape):
            self.fail(
                f"shape mismatch expected={expected.shape} got={mx_res.shape}"
            )
        if mx_res.dtype != expected.dtype:
            self.fail(
                f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
            )
        np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|