mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	feat: Add numpy constants (#428)
* add numpy constants * feat: add unittests * add newaxis * add test for newaxis transformation * refactor
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							c92a134b0d
						
					
				
				
					commit
					975e265f74
				
			@@ -12,6 +12,7 @@ pybind11_add_module(
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										24
									
								
								python/src/constants.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								python/src/constants.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,24 @@
 | 
			
		||||
// init_constants.cpp
 | 
			
		||||
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <limits>
 | 
			
		||||
 | 
			
		||||
namespace py = pybind11;
 | 
			
		||||
 | 
			
		||||
void init_constants(py::module_& m) {
 | 
			
		||||
  m.attr("Inf") = std::numeric_limits<double>::infinity();
 | 
			
		||||
  m.attr("Infinity") = std::numeric_limits<double>::infinity();
 | 
			
		||||
  m.attr("NAN") = NAN;
 | 
			
		||||
  m.attr("NINF") = -std::numeric_limits<double>::infinity();
 | 
			
		||||
  m.attr("NZERO") = -0.0;
 | 
			
		||||
  m.attr("NaN") = NAN;
 | 
			
		||||
  m.attr("PINF") = std::numeric_limits<double>::infinity();
 | 
			
		||||
  m.attr("PZERO") = 0.0;
 | 
			
		||||
  m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
 | 
			
		||||
  m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
 | 
			
		||||
  m.attr("inf") = std::numeric_limits<double>::infinity();
 | 
			
		||||
  m.attr("infty") = std::numeric_limits<double>::infinity();
 | 
			
		||||
  m.attr("nan") = NAN;
 | 
			
		||||
  m.attr("newaxis") = pybind11::none();
 | 
			
		||||
  m.attr("pi") = 3.1415926535897932384626433;
 | 
			
		||||
}
 | 
			
		||||
@@ -16,6 +16,7 @@ void init_transforms(py::module_&);
 | 
			
		||||
void init_random(py::module_&);
 | 
			
		||||
void init_fft(py::module_&);
 | 
			
		||||
void init_linalg(py::module_&);
 | 
			
		||||
void init_constants(py::module_&);
 | 
			
		||||
 | 
			
		||||
PYBIND11_MODULE(core, m) {
 | 
			
		||||
  m.doc() = "mlx: A framework for machine learning on Apple silicon.";
 | 
			
		||||
@@ -31,5 +32,6 @@ PYBIND11_MODULE(core, m) {
 | 
			
		||||
  init_random(m);
 | 
			
		||||
  init_fft(m);
 | 
			
		||||
  init_linalg(m);
 | 
			
		||||
  init_constants(m);
 | 
			
		||||
  m.attr("__version__") = TOSTRING(_VERSION_);
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user