mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
feat: Add numpy constants (#428)
* add numpy constants * feat: add unittests * add newaxis * add test for newaxis transformation * refactor
This commit is contained in:
parent
c92a134b0d
commit
975e265f74
@ -12,6 +12,7 @@ pybind11_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||||
|
24
python/src/constants.cpp
Normal file
24
python/src/constants.cpp
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// init_constants.cpp
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
void init_constants(py::module_& m) {
|
||||||
|
m.attr("Inf") = std::numeric_limits<double>::infinity();
|
||||||
|
m.attr("Infinity") = std::numeric_limits<double>::infinity();
|
||||||
|
m.attr("NAN") = NAN;
|
||||||
|
m.attr("NINF") = -std::numeric_limits<double>::infinity();
|
||||||
|
m.attr("NZERO") = -0.0;
|
||||||
|
m.attr("NaN") = NAN;
|
||||||
|
m.attr("PINF") = std::numeric_limits<double>::infinity();
|
||||||
|
m.attr("PZERO") = 0.0;
|
||||||
|
m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
|
||||||
|
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
|
||||||
|
m.attr("inf") = std::numeric_limits<double>::infinity();
|
||||||
|
m.attr("infty") = std::numeric_limits<double>::infinity();
|
||||||
|
m.attr("nan") = NAN;
|
||||||
|
m.attr("newaxis") = pybind11::none();
|
||||||
|
m.attr("pi") = 3.1415926535897932384626433;
|
||||||
|
}
|
@ -16,6 +16,7 @@ void init_transforms(py::module_&);
|
|||||||
void init_random(py::module_&);
|
void init_random(py::module_&);
|
||||||
void init_fft(py::module_&);
|
void init_fft(py::module_&);
|
||||||
void init_linalg(py::module_&);
|
void init_linalg(py::module_&);
|
||||||
|
void init_constants(py::module_&);
|
||||||
|
|
||||||
PYBIND11_MODULE(core, m) {
|
PYBIND11_MODULE(core, m) {
|
||||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||||
@ -31,5 +32,6 @@ PYBIND11_MODULE(core, m) {
|
|||||||
init_random(m);
|
init_random(m);
|
||||||
init_fft(m);
|
init_fft(m);
|
||||||
init_linalg(m);
|
init_linalg(m);
|
||||||
|
init_constants(m);
|
||||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||||
}
|
}
|
||||||
|
59
python/tests/test_constants.py
Normal file
59
python/tests/test_constants.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstants(mlx_tests.MLXTestCase):
|
||||||
|
def test_constants_values(self):
|
||||||
|
# Check if mlx constants match expected values
|
||||||
|
self.assertAlmostEqual(mx.Inf, float("inf"))
|
||||||
|
self.assertAlmostEqual(mx.Infinity, float("inf"))
|
||||||
|
self.assertTrue(np.isnan(mx.NAN))
|
||||||
|
self.assertAlmostEqual(mx.NINF, float("-inf"))
|
||||||
|
self.assertEqual(mx.NZERO, -0.0)
|
||||||
|
self.assertTrue(np.isnan(mx.NaN))
|
||||||
|
self.assertAlmostEqual(mx.PINF, float("inf"))
|
||||||
|
self.assertEqual(mx.PZERO, 0.0)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
mx.e, 2.71828182845904523536028747135266249775724709369995
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
mx.euler_gamma, 0.5772156649015328606065120900824024310421
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(mx.inf, float("inf"))
|
||||||
|
self.assertAlmostEqual(mx.infty, float("inf"))
|
||||||
|
self.assertTrue(np.isnan(mx.nan))
|
||||||
|
self.assertIsNone(mx.newaxis)
|
||||||
|
self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)
|
||||||
|
|
||||||
|
def test_constants_availability(self):
|
||||||
|
# Check if mlx constants are available
|
||||||
|
self.assertTrue(hasattr(mx, "Inf"))
|
||||||
|
self.assertTrue(hasattr(mx, "Infinity"))
|
||||||
|
self.assertTrue(hasattr(mx, "NAN"))
|
||||||
|
self.assertTrue(hasattr(mx, "NINF"))
|
||||||
|
self.assertTrue(hasattr(mx, "NaN"))
|
||||||
|
self.assertTrue(hasattr(mx, "PINF"))
|
||||||
|
self.assertTrue(hasattr(mx, "NZERO"))
|
||||||
|
self.assertTrue(hasattr(mx, "PZERO"))
|
||||||
|
self.assertTrue(hasattr(mx, "e"))
|
||||||
|
self.assertTrue(hasattr(mx, "euler_gamma"))
|
||||||
|
self.assertTrue(hasattr(mx, "inf"))
|
||||||
|
self.assertTrue(hasattr(mx, "infty"))
|
||||||
|
self.assertTrue(hasattr(mx, "nan"))
|
||||||
|
self.assertTrue(hasattr(mx, "newaxis"))
|
||||||
|
self.assertTrue(hasattr(mx, "pi"))
|
||||||
|
|
||||||
|
def test_newaxis_for_reshaping_arrays(self):
|
||||||
|
arr_1d = mx.array([1, 2, 3, 4, 5])
|
||||||
|
arr_2d_column = arr_1d[:, mx.newaxis]
|
||||||
|
expected_result = mx.array([[1], [2], [3], [4], [5]])
|
||||||
|
self.assertTrue(mx.array_equal(arr_2d_column, expected_result))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user