From 975e265f74951311b2f2935ef663af14ec373c1c Mon Sep 17 00:00:00 2001 From: Avikant Srivastava Date: Thu, 11 Jan 2024 20:17:29 +0530 Subject: [PATCH] feat: Add numpy constants (#428) * add numpy constants * feat: add unittests * add newaxis * add test for newaxis transformation * refactor --- python/src/CMakeLists.txt | 1 + python/src/constants.cpp | 24 ++++++++++++++ python/src/mlx.cpp | 2 ++ python/tests/test_constants.py | 59 ++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+) create mode 100644 python/src/constants.cpp create mode 100644 python/tests/test_constants.py diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 1ad9d207d..1ba037fdc 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -12,6 +12,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/constants.cpp b/python/src/constants.cpp new file mode 100644 index 000000000..94658b586 --- /dev/null +++ b/python/src/constants.cpp @@ -0,0 +1,24 @@ +// init_constants.cpp + +#include +#include + +namespace py = pybind11; + +void init_constants(py::module_& m) { + m.attr("Inf") = std::numeric_limits::infinity(); + m.attr("Infinity") = std::numeric_limits::infinity(); + m.attr("NAN") = NAN; + m.attr("NINF") = -std::numeric_limits::infinity(); + m.attr("NZERO") = -0.0; + m.attr("NaN") = NAN; + m.attr("PINF") = std::numeric_limits::infinity(); + m.attr("PZERO") = 0.0; + m.attr("e") = 2.71828182845904523536028747135266249775724709369995; + m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421; + m.attr("inf") = std::numeric_limits::infinity(); + m.attr("infty") = std::numeric_limits::infinity(); + m.attr("nan") = NAN; + m.attr("newaxis") = pybind11::none(); + m.attr("pi") = 3.1415926535897932384626433; +} \ No newline at end of file diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index d7cf15751..81626e565 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -16,6 +16,7 @@ void init_transforms(py::module_&); void init_random(py::module_&); void init_fft(py::module_&); void init_linalg(py::module_&); +void init_constants(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -31,5 +32,6 @@ PYBIND11_MODULE(core, m) { init_random(m); init_fft(m); init_linalg(m); + init_constants(m); m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/tests/test_constants.py b/python/tests/test_constants.py new file mode 100644 index 000000000..11a466e03 --- /dev/null +++ b/python/tests/test_constants.py @@ -0,0 +1,59 @@ +# Copyright © 2023 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests +import numpy as np + + +class TestConstants(mlx_tests.MLXTestCase): + def test_constants_values(self): + # Check if mlx constants match expected values + self.assertAlmostEqual(mx.Inf, float("inf")) + self.assertAlmostEqual(mx.Infinity, float("inf")) + self.assertTrue(np.isnan(mx.NAN)) + self.assertAlmostEqual(mx.NINF, float("-inf")) + self.assertEqual(mx.NZERO, -0.0) + self.assertTrue(np.isnan(mx.NaN)) + self.assertAlmostEqual(mx.PINF, float("inf")) + self.assertEqual(mx.PZERO, 0.0) + self.assertAlmostEqual( + mx.e, 2.71828182845904523536028747135266249775724709369995 + ) + self.assertAlmostEqual( + mx.euler_gamma, 0.5772156649015328606065120900824024310421 + ) + self.assertAlmostEqual(mx.inf, float("inf")) + self.assertAlmostEqual(mx.infty, float("inf")) + self.assertTrue(np.isnan(mx.nan)) + self.assertIsNone(mx.newaxis) + self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433) + + def test_constants_availability(self): + # Check if mlx constants are available + self.assertTrue(hasattr(mx, "Inf")) + self.assertTrue(hasattr(mx, "Infinity")) + self.assertTrue(hasattr(mx, "NAN")) + self.assertTrue(hasattr(mx, "NINF")) + self.assertTrue(hasattr(mx, "NaN")) + self.assertTrue(hasattr(mx, "PINF")) + self.assertTrue(hasattr(mx, "NZERO")) + self.assertTrue(hasattr(mx, "PZERO")) + self.assertTrue(hasattr(mx, "e")) + self.assertTrue(hasattr(mx, "euler_gamma")) + self.assertTrue(hasattr(mx, "inf")) + self.assertTrue(hasattr(mx, "infty")) + self.assertTrue(hasattr(mx, "nan")) + self.assertTrue(hasattr(mx, "newaxis")) + self.assertTrue(hasattr(mx, "pi")) + + def test_newaxis_for_reshaping_arrays(self): + arr_1d = mx.array([1, 2, 3, 4, 5]) + arr_2d_column = arr_1d[:, mx.newaxis] + expected_result = mx.array([[1], [2], [3], [4], [5]]) + self.assertTrue(mx.array_equal(arr_2d_column, expected_result)) + + +if __name__ == "__main__": + unittest.main()