mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide * Add floor_divide to the tests * Add floor_divide to the docs
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							de892cb66c
						
					
				
				
					commit
					2807c6aff0
				
			| @@ -636,8 +636,7 @@ void init_array(py::module_& m) { | ||||
|           "__floordiv__", | ||||
|           [](const array& a, const ScalarOrArray v) { | ||||
|             auto b = to_array(v, a.dtype()); | ||||
|             auto t = promote_types(a.dtype(), b.dtype()); | ||||
|             return astype(divide(a, b), t); | ||||
|             return floor_divide(a, b); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
| @@ -650,8 +649,7 @@ void init_array(py::module_& m) { | ||||
|           "__rfloordiv__", | ||||
|           [](const array& a, const ScalarOrArray v) { | ||||
|             auto b = to_array(v, a.dtype()); | ||||
|             auto t = promote_types(a.dtype(), b.dtype()); | ||||
|             return astype(divide(b, a), t); | ||||
|             return floor_divide(b, a); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
|   | ||||
| @@ -303,6 +303,32 @@ void init_ops(py::module_& m) { | ||||
|         Returns: | ||||
|             array: The quotient ``a / b``. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "floor_divide", | ||||
|       [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { | ||||
|         auto [a, b] = to_arrays(a_, b_); | ||||
|         return floor_divide(a, b, s); | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "b"_a, | ||||
|       py::pos_only(), | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Element-wise integer division. | ||||
|  | ||||
|         If either array is a floating point type then it is equivalent to | ||||
|         calling :func:`floor` after :func:`divide`. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array or scalar. | ||||
|             b (array): Input array or scalar. | ||||
|  | ||||
|         Returns: | ||||
|             array: The quotient ``a // b``. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "remainder", | ||||
|       [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { | ||||
|   | ||||
| @@ -115,6 +115,7 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             "subtract", | ||||
|             "multiply", | ||||
|             "divide", | ||||
|             "floor_divide", | ||||
|             "remainder", | ||||
|             "equal", | ||||
|             "not_equal", | ||||
| @@ -1096,6 +1097,7 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             "subtract", | ||||
|             "multiply", | ||||
|             "divide", | ||||
|             "floor_divide", | ||||
|             "maximum", | ||||
|             "minimum", | ||||
|             "power", | ||||
| @@ -1111,19 +1113,21 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                     "uint32", | ||||
|                     "uint64", | ||||
|                 ] | ||||
|  | ||||
|                 float_dtypes = ["float16", "float32"] | ||||
|  | ||||
|                 dtypes = ( | ||||
|                     float_dtypes | ||||
|                     if op in ("divide", "power") | ||||
|                     else (int_dtypes + float_dtypes) | ||||
|                 ) | ||||
|                 dtypes = { | ||||
|                     "divide": float_dtypes, | ||||
|                     "power": float_dtypes, | ||||
|                     "floor_divide": ["float32"] + int_dtypes, | ||||
|                 } | ||||
|                 dtypes = dtypes.get(op, int_dtypes + float_dtypes) | ||||
|  | ||||
|                 for dtype in dtypes: | ||||
|                     atol = 1e-3 if dtype == "float16" else 1e-6 | ||||
|                     with self.subTest(dtype=dtype): | ||||
|                         x1_ = x1.astype(getattr(np, dtype)) | ||||
|                         x2_ = x2.astype(getattr(np, dtype)) | ||||
|                         m = 10 if dtype in int_dtypes else 1 | ||||
|                         x1_ = (x1 * m).astype(getattr(np, dtype)) | ||||
|                         x2_ = (x2 * m).astype(getattr(np, dtype)) | ||||
|                         y1_ = mx.array(x1_) | ||||
|                         y2_ = mx.array(x2_) | ||||
|                         test_ops( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user