mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
This commit is contained in:
@@ -1371,6 +1371,37 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
def test_stack(self):
|
||||
a = mx.ones((2,))
|
||||
np_a = np.ones((2,))
|
||||
b = mx.ones((2,))
|
||||
np_b = np.ones((2,))
|
||||
|
||||
# One dimensional stack axis=0
|
||||
c = mx.stack([a, b])
|
||||
np_c = np.stack([np_a, np_b])
|
||||
self.assertTrue(np.array_equal(c, np_c))
|
||||
|
||||
# One dimensional stack axis=1
|
||||
c = mx.stack([a, b], axis=1)
|
||||
np_c = np.stack([np_a, np_b], axis=1)
|
||||
self.assertTrue(np.array_equal(c, np_c))
|
||||
|
||||
a = mx.ones((1, 2))
|
||||
np_a = np.ones((1, 2))
|
||||
b = mx.ones((1, 2))
|
||||
np_b = np.ones((1, 2))
|
||||
|
||||
# Two dimensional stack axis=0
|
||||
c = mx.stack([a, b])
|
||||
np_c = np.stack([np_a, np_b])
|
||||
self.assertTrue(np.array_equal(c, np_c))
|
||||
|
||||
# Two dimensional stack axis=1
|
||||
c = mx.stack([a, b], axis=1)
|
||||
np_c = np.stack([np_a, np_b], axis=1)
|
||||
self.assertTrue(np.array_equal(c, np_c))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user