mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -94,7 +94,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.itemsize, 4)
|
||||
self.assertEqual(x.nbytes, 4)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.int32)
|
||||
self.assertEqual(x.item(), 1)
|
||||
self.assertTrue(isinstance(x.item(), int))
|
||||
@@ -116,7 +116,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(1.0)
|
||||
self.assertEqual(x.size, 1)
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.item(), 1.0)
|
||||
self.assertTrue(isinstance(x.item(), float))
|
||||
@@ -124,14 +124,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(False)
|
||||
self.assertEqual(x.size, 1)
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.bool_)
|
||||
self.assertEqual(x.item(), False)
|
||||
self.assertTrue(isinstance(x.item(), bool))
|
||||
|
||||
x = mx.array(complex(1, 1))
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.complex64)
|
||||
self.assertEqual(x.item(), complex(1, 1))
|
||||
self.assertTrue(isinstance(x.item(), complex))
|
||||
@@ -139,7 +139,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([True, False, True])
|
||||
self.assertEqual(x.dtype, mx.bool_)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [3])
|
||||
self.assertEqual(x.shape, (3,))
|
||||
self.assertEqual(len(x), 3)
|
||||
|
||||
x = mx.array([True, False, True], mx.float32)
|
||||
@@ -148,7 +148,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([0, 1, 2])
|
||||
self.assertEqual(x.dtype, mx.int32)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [3])
|
||||
self.assertEqual(x.shape, (3,))
|
||||
|
||||
x = mx.array([0, 1, 2], mx.float32)
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
@@ -156,12 +156,12 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([0.0, 1.0, 2.0])
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [3])
|
||||
self.assertEqual(x.shape, (3,))
|
||||
|
||||
x = mx.array([1j, 1 + 0j])
|
||||
self.assertEqual(x.dtype, mx.complex64)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
|
||||
# From tuple
|
||||
x = mx.array((1, 2, 3), mx.int32)
|
||||
@@ -181,17 +181,17 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
def test_construction_from_lists(self):
|
||||
x = mx.array([])
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [0])
|
||||
self.assertEqual(x.shape, (0,))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
x = mx.array([[], [], []])
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 0])
|
||||
self.assertEqual(x.shape, (3, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
x = mx.array([[[], []], [[], []], [[], []]])
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 2, 0])
|
||||
self.assertEqual(x.shape, (3, 2, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
# Check failure cases
|
||||
@@ -436,19 +436,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a = np.array([])
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [0])
|
||||
self.assertEqual(x.shape, (0,))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
a = np.array([[], [], []])
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 0])
|
||||
self.assertEqual(x.shape, (3, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
a = np.array([[[], []], [[], []], [[], []]])
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 2, 0])
|
||||
self.assertEqual(x.shape, (3, 2, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
# Content test
|
||||
@@ -456,7 +456,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.ndim, 3)
|
||||
self.assertEqual(x.shape, [3, 5, 4])
|
||||
self.assertEqual(x.shape, (3, 5, 4))
|
||||
|
||||
y = np.asarray(x)
|
||||
self.assertTrue(np.allclose(a, y))
|
||||
@@ -465,7 +465,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.dtype, mx.int32)
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
# mlx to numpy test
|
||||
@@ -483,7 +483,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = np.array(cvals)
|
||||
y = mx.array(x)
|
||||
self.assertEqual(y.dtype, mx.complex64)
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.tolist(), cvals)
|
||||
|
||||
y = mx.array([0j, 1, 1 + 1j])
|
||||
|
@@ -579,5 +579,9 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
|
||||
|
||||
for r, t in zip(dout_ref, dout_test):
|
||||
self.assertListEqual(r.shape, t.shape)
|
||||
self.assertEqual(r.shape, t.shape)
|
||||
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -44,7 +44,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
c_np = np.convolve(a_np, v_np, mode=mode)
|
||||
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
|
||||
|
||||
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
|
||||
self.assertEqual(c_mx.shape, c_np.shape)
|
||||
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
@@ -102,7 +102,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
@@ -141,7 +141,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
out_pt = torch.conv1d(in_pt, wt_pt)
|
||||
out_pt = torch.transpose(out_pt, 2, 1)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), out_mx.shape)
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
@@ -228,12 +228,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
@@ -309,7 +309,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||
|
||||
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
@@ -419,12 +419,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
mx_grad_in, mx_grad_wt = outs_mx
|
||||
|
||||
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
|
||||
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
|
||||
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
|
||||
self.assertEqual(in_mx.shape, mx_grad_in.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
|
||||
|
||||
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
|
||||
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
|
||||
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
|
||||
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
|
||||
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
|
@@ -13,7 +13,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.constant(value, dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(mx.zeros(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -24,7 +24,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
std = 1.0
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.normal(mean, std, dtype=dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -36,7 +36,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.uniform(low, high, dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -46,7 +46,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.identity(dtype)
|
||||
for shape in [[3], [3, 3], [3, 3, 3]]:
|
||||
for shape in [(3,), (3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.zeros((3, 3)))
|
||||
self.assertTrue(mx.array_equal(result, mx.eye(3)))
|
||||
self.assertEqual(result.dtype, dtype)
|
||||
@@ -56,7 +56,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_glorot_normal(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.glorot_normal(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -65,7 +65,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_glorot_uniform(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.glorot_uniform(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -74,7 +74,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_he_normal(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.he_normal(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
@@ -83,7 +83,7 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
def test_he_uniform(self):
|
||||
for dtype in [mx.float32, mx.float16]:
|
||||
initializer = init.he_uniform(dtype)
|
||||
for shape in [[3, 3], [3, 3, 3]]:
|
||||
for shape in [(3, 3), (3, 3, 3)]:
|
||||
result = initializer(mx.array(np.empty(shape)))
|
||||
with self.subTest(shape=shape):
|
||||
self.assertEqual(result.shape, shape)
|
||||
|
@@ -136,20 +136,20 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
inputs = mx.zeros((10, 4))
|
||||
layer = nn.Identity()
|
||||
outputs = layer(inputs)
|
||||
self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
|
||||
self.assertEqual(inputs.shape, outputs.shape)
|
||||
|
||||
def test_linear(self):
|
||||
inputs = mx.zeros((10, 4))
|
||||
layer = nn.Linear(input_dims=4, output_dims=8)
|
||||
outputs = layer(inputs)
|
||||
self.assertEqual(tuple(outputs.shape), (10, 8))
|
||||
self.assertEqual(outputs.shape, (10, 8))
|
||||
|
||||
def test_bilinear(self):
|
||||
inputs1 = mx.zeros((10, 2))
|
||||
inputs2 = mx.zeros((10, 4))
|
||||
layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
|
||||
outputs = layer(inputs1, inputs2)
|
||||
self.assertEqual(tuple(outputs.shape), (10, 6))
|
||||
self.assertEqual(outputs.shape, (10, 6))
|
||||
|
||||
def test_group_norm(self):
|
||||
x = mx.arange(100, dtype=mx.float32)
|
||||
@@ -573,12 +573,12 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
|
||||
c.weight = mx.ones_like(c.weight)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [N, L - ks + 1, C_out])
|
||||
self.assertEqual(y.shape, (N, L - ks + 1, C_out))
|
||||
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
|
||||
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
|
||||
self.assertEqual(y.shape, (N, (L - ks + 1) // 2, C_out))
|
||||
self.assertTrue("bias" in c.parameters())
|
||||
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
|
||||
@@ -588,7 +588,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
x = mx.ones((4, 8, 8, 3))
|
||||
c = nn.Conv2d(3, 1, 8)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 1, 1, 1])
|
||||
self.assertEqual(y.shape, (4, 1, 1, 1))
|
||||
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
|
||||
y = c(x)
|
||||
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
|
||||
@@ -596,13 +596,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
# 3x3 conv no padding stride 1
|
||||
c = nn.Conv2d(3, 8, 3)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 6, 6, 8])
|
||||
self.assertEqual(y.shape, (4, 6, 6, 8))
|
||||
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||
|
||||
# 3x3 conv padding 1 stride 1
|
||||
c = nn.Conv2d(3, 8, 3, padding=1)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 8, 8, 8])
|
||||
self.assertEqual(y.shape, (4, 8, 8, 8))
|
||||
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||
self.assertLess(
|
||||
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
|
||||
@@ -624,14 +624,14 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
# 3x3 conv no padding stride 2
|
||||
c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
|
||||
y = c(x)
|
||||
self.assertEqual(y.shape, [4, 3, 3, 8])
|
||||
self.assertEqual(y.shape, (4, 3, 3, 8))
|
||||
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
|
||||
|
||||
def test_sequential(self):
|
||||
x = mx.ones((10, 2))
|
||||
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
|
||||
y = m(x)
|
||||
self.assertEqual(y.shape, [10, 1])
|
||||
self.assertEqual(y.shape, (10, 1))
|
||||
params = m.parameters()
|
||||
self.assertTrue("layers" in params)
|
||||
self.assertEqual(len(params["layers"]), 3)
|
||||
@@ -667,7 +667,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
x = mx.arange(10)
|
||||
y = m(x)
|
||||
|
||||
self.assertEqual(y.shape, [10, 16])
|
||||
self.assertEqual(y.shape, (10, 16))
|
||||
similarities = y @ y.T
|
||||
self.assertLess(
|
||||
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
|
||||
@@ -686,19 +686,19 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.relu(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_leaky_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.leaky_relu(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.LeakyReLU(negative_slope=0.1)(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_elu(self):
|
||||
@@ -707,21 +707,21 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.0, -0.6321, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.ELU(alpha=1.1)(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.0, -0.6953, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_relu6(self):
|
||||
x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
|
||||
y = nn.relu6(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
|
||||
self.assertEqual(y.shape, [5])
|
||||
self.assertEqual(y.shape, (5,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softmax(self):
|
||||
@@ -730,7 +730,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.6652, 0.0900, 0.2447])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softplus(self):
|
||||
@@ -739,7 +739,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.3133, 0.3133, 0.6931])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softsign(self):
|
||||
@@ -748,7 +748,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.5, -0.5, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softshrink(self):
|
||||
@@ -757,13 +757,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.5, -0.5, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.Softshrink(lambd=0.7)(x)
|
||||
expected_y = mx.array([0.3, -0.3, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_celu(self):
|
||||
@@ -772,13 +772,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.0, -0.6321, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.CELU(alpha=1.1)(x)
|
||||
expected_y = mx.array([1.0, -0.6568, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_log_softmax(self):
|
||||
@@ -787,7 +787,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([-2.4076, -1.4076, -0.4076])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_log_sigmoid(self):
|
||||
@@ -796,7 +796,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([-0.3133, -1.3133, -0.6931])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_prelu(self):
|
||||
@@ -817,7 +817,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [5])
|
||||
self.assertEqual(y.shape, (5,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_glu(self):
|
||||
|
@@ -12,12 +12,12 @@ import numpy as np
|
||||
class TestOps(mlx_tests.MLXTestCase):
|
||||
def test_full_ones_zeros(self):
|
||||
x = mx.full(2, 3.0)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
self.assertEqual(x.tolist(), [3.0, 3.0])
|
||||
|
||||
x = mx.full((2, 3), 2.0)
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.shape, [2, 3])
|
||||
self.assertEqual(x.shape, (2, 3))
|
||||
self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])
|
||||
|
||||
x = mx.full([3, 2], mx.array([False, True]))
|
||||
@@ -28,11 +28,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])
|
||||
|
||||
x = mx.zeros(2)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
self.assertEqual(x.tolist(), [0.0, 0.0])
|
||||
|
||||
x = mx.ones(2)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
self.assertEqual(x.tolist(), [1.0, 1.0])
|
||||
|
||||
for t in [mx.bool_, mx.int32, mx.float32]:
|
||||
@@ -530,10 +530,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_move_swap_axes(self):
|
||||
x = mx.zeros((2, 3, 4))
|
||||
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
|
||||
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
|
||||
self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))
|
||||
self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))
|
||||
self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))
|
||||
self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))
|
||||
|
||||
def test_sum(self):
|
||||
x = mx.array(
|
||||
@@ -545,7 +545,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.sum(x).item(), 9)
|
||||
y = mx.sum(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(9))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])
|
||||
self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])
|
||||
@@ -585,7 +585,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.prod(x).item(), 18)
|
||||
y = mx.prod(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(18))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])
|
||||
self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])
|
||||
@@ -600,11 +600,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.min(x).item(), 1)
|
||||
self.assertEqual(mx.max(x).item(), 4)
|
||||
y = mx.min(x, keepdims=True)
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
self.assertEqual(y, mx.array(1))
|
||||
|
||||
y = mx.max(x, keepdims=True)
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
self.assertEqual(y, mx.array(4))
|
||||
|
||||
self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])
|
||||
@@ -670,7 +670,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.mean(x).item(), 2.5)
|
||||
y = mx.mean(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(2.5))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
|
||||
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
|
||||
@@ -685,7 +685,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.var(x).item(), 1.25)
|
||||
y = mx.var(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(1.25))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
|
||||
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
|
||||
@@ -888,7 +888,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.array([[True, False], [True, True]])
|
||||
|
||||
self.assertFalse(mx.all(a).item())
|
||||
self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1])
|
||||
self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))
|
||||
self.assertFalse(mx.all(a, axis=[0, 1]).item())
|
||||
self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])
|
||||
self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])
|
||||
@@ -899,7 +899,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.array([[True, False], [False, False]])
|
||||
|
||||
self.assertTrue(mx.any(a).item())
|
||||
self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1])
|
||||
self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))
|
||||
self.assertTrue(mx.any(a, axis=[0, 1]).item())
|
||||
self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])
|
||||
self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])
|
||||
@@ -956,22 +956,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy, axis=0)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy, axis=1)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy, axis=2)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
def test_take_along_axis(self):
|
||||
@@ -1400,13 +1400,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
|
||||
|
||||
a = mx.zeros((1, 1, 1))
|
||||
self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3])
|
||||
self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3])
|
||||
self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3])
|
||||
self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4])
|
||||
self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4])
|
||||
self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4])
|
||||
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5])
|
||||
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
|
||||
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
|
||||
self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))
|
||||
self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))
|
||||
self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))
|
||||
self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))
|
||||
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))
|
||||
|
||||
# Test grads
|
||||
a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))
|
||||
@@ -1490,19 +1490,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_squeeze_expand(self):
|
||||
a = mx.zeros((2, 1, 2, 1))
|
||||
self.assertEqual(mx.squeeze(a).shape, [2, 2])
|
||||
self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1])
|
||||
self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2])
|
||||
self.assertEqual(a.squeeze().shape, [2, 2])
|
||||
self.assertEqual(a.squeeze(1).shape, [2, 2, 1])
|
||||
self.assertEqual(a.squeeze([1, 3]).shape, [2, 2])
|
||||
self.assertEqual(mx.squeeze(a).shape, (2, 2))
|
||||
self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))
|
||||
self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))
|
||||
self.assertEqual(a.squeeze().shape, (2, 2))
|
||||
self.assertEqual(a.squeeze(1).shape, (2, 2, 1))
|
||||
self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))
|
||||
|
||||
a = mx.zeros((2, 2))
|
||||
self.assertEqual(mx.squeeze(a).shape, [2, 2])
|
||||
self.assertEqual(mx.squeeze(a).shape, (2, 2))
|
||||
|
||||
self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2])
|
||||
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2])
|
||||
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1])
|
||||
self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))
|
||||
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))
|
||||
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
|
||||
|
||||
def test_sort(self):
|
||||
shape = (3, 4, 5)
|
||||
@@ -1603,12 +1603,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_flatten(self):
|
||||
x = mx.zeros([2, 3, 4])
|
||||
self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
|
||||
self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
|
||||
self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
|
||||
self.assertEqual(x.flatten().shape, [2 * 3 * 4])
|
||||
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
|
||||
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
|
||||
self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))
|
||||
self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))
|
||||
self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))
|
||||
self.assertEqual(x.flatten().shape, (2 * 3 * 4,))
|
||||
self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))
|
||||
self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))
|
||||
|
||||
def test_clip(self):
|
||||
a = np.array([1, 4, 3, 8, 5], np.int32)
|
||||
|
@@ -38,19 +38,19 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(k2, r2))
|
||||
|
||||
keys = mx.random.split(key, 10)
|
||||
self.assertEqual(keys.shape, [10, 2])
|
||||
self.assertEqual(keys.shape, (10, 2))
|
||||
|
||||
def test_uniform(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.uniform(key=key)
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.shape, ())
|
||||
self.assertEqual(a.dtype, mx.float32)
|
||||
|
||||
b = mx.random.uniform(key=key)
|
||||
self.assertEqual(a.item(), b.item())
|
||||
|
||||
a = mx.random.uniform(shape=(2, 3))
|
||||
self.assertEqual(a.shape, [2, 3])
|
||||
self.assertEqual(a.shape, (2, 3))
|
||||
|
||||
a = mx.random.uniform(shape=(1000,), low=-1, high=5)
|
||||
self.assertTrue(mx.all((a > -1) < 5).item())
|
||||
@@ -66,14 +66,14 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
def test_normal(self):
|
||||
key = mx.random.key(0)
|
||||
a = mx.random.normal(key=key)
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.shape, ())
|
||||
self.assertEqual(a.dtype, mx.float32)
|
||||
|
||||
b = mx.random.normal(key=key)
|
||||
self.assertEqual(a.item(), b.item())
|
||||
|
||||
a = mx.random.normal(shape=(2, 3))
|
||||
self.assertEqual(a.shape, [2, 3])
|
||||
self.assertEqual(a.shape, (2, 3))
|
||||
|
||||
## Generate in float16 or bfloat16
|
||||
for t in [mx.float16, mx.bfloat16]:
|
||||
@@ -84,10 +84,10 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_randint(self):
|
||||
a = mx.random.randint(0, 1, [])
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.shape, ())
|
||||
self.assertEqual(a.dtype, mx.int32)
|
||||
|
||||
shape = [88]
|
||||
shape = (88,)
|
||||
low = mx.array(3)
|
||||
high = mx.array(15)
|
||||
|
||||
@@ -100,7 +100,7 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
b = mx.random.randint(low, high, shape, key=key)
|
||||
self.assertListEqual(a.tolist(), b.tolist())
|
||||
|
||||
shape = [3, 4]
|
||||
shape = (3, 4)
|
||||
low = mx.reshape(mx.array([0] * 3), [3, 1])
|
||||
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
|
||||
|
||||
@@ -119,20 +119,20 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_bernoulli(self):
|
||||
a = mx.random.bernoulli()
|
||||
self.assertEqual(a.shape, [])
|
||||
self.assertEqual(a.shape, ())
|
||||
self.assertEqual(a.dtype, mx.bool_)
|
||||
|
||||
a = mx.random.bernoulli(mx.array(0.5), [5])
|
||||
self.assertEqual(a.shape, [5])
|
||||
self.assertEqual(a.shape, (5,))
|
||||
|
||||
a = mx.random.bernoulli(mx.array([2.0, -2.0]))
|
||||
self.assertEqual(a.tolist(), [True, False])
|
||||
self.assertEqual(a.shape, [2])
|
||||
self.assertEqual(a.shape, (2,))
|
||||
|
||||
p = mx.array([0.1, 0.2, 0.3])
|
||||
mx.reshape(p, [1, 3])
|
||||
x = mx.random.bernoulli(p, [4, 3])
|
||||
self.assertEqual(x.shape, [4, 3])
|
||||
self.assertEqual(x.shape, (4, 3))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.bernoulli(p, [2]) # Bad shape
|
||||
@@ -153,14 +153,14 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
|
||||
a = mx.random.truncated_normal(lower, upper)
|
||||
|
||||
self.assertEqual(a.shape, [3, 2])
|
||||
self.assertEqual(a.shape, (3, 2))
|
||||
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
|
||||
|
||||
a = mx.random.truncated_normal(2.0, -2.0)
|
||||
self.assertTrue(mx.all(a == 2.0).item())
|
||||
|
||||
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
|
||||
self.assertEqual(a.shape, [542, 399])
|
||||
self.assertEqual(a.shape, (542, 399))
|
||||
|
||||
lower = mx.array([-2.0, -1.0])
|
||||
higher = mx.array([1.0, 2.0, 3.0])
|
||||
@@ -174,7 +174,7 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_gumbel(self):
|
||||
samples = mx.random.gumbel(shape=(100, 100))
|
||||
self.assertEqual(samples.shape, [100, 100])
|
||||
self.assertEqual(samples.shape, (100, 100))
|
||||
self.assertEqual(samples.dtype, mx.float32)
|
||||
mean = 0.5772
|
||||
# Std deviation of the sample mean is small (<0.02),
|
||||
@@ -187,23 +187,23 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_categorical(self):
|
||||
logits = mx.zeros((10, 20))
|
||||
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
|
||||
self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
|
||||
self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
|
||||
self.assertEqual(mx.random.categorical(logits, -1).shape, (10,))
|
||||
self.assertEqual(mx.random.categorical(logits, 0).shape, (20,))
|
||||
self.assertEqual(mx.random.categorical(logits, 1).shape, (10,))
|
||||
|
||||
out = mx.random.categorical(logits)
|
||||
self.assertEqual(out.shape, [10])
|
||||
self.assertEqual(out.shape, (10,))
|
||||
self.assertEqual(out.dtype, mx.uint32)
|
||||
self.assertTrue(mx.max(out).item() < 20)
|
||||
|
||||
out = mx.random.categorical(logits, 0, [5, 20])
|
||||
self.assertEqual(out.shape, [5, 20])
|
||||
self.assertEqual(out.shape, (5, 20))
|
||||
self.assertTrue(mx.max(out).item() < 10)
|
||||
|
||||
out = mx.random.categorical(logits, 1, num_samples=7)
|
||||
self.assertEqual(out.shape, [10, 7])
|
||||
self.assertEqual(out.shape, (10, 7))
|
||||
out = mx.random.categorical(logits, 0, num_samples=7)
|
||||
self.assertEqual(out.shape, [20, 7])
|
||||
self.assertEqual(out.shape, (20, 7))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
|
||||
|
Reference in New Issue
Block a user