From 07f35c9d8a7b1ebea5184ef97d8ca480bd5dab54 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 26 Jan 2024 15:16:46 -0800 Subject: [PATCH] Fix a few issues: docs for flatten, erf, dequantize validation (#560) * doc flatten * erf doc * check values for dequantize * format --- mlx/ops.cpp | 10 ++++++++++ python/src/ops.cpp | 15 ++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2012613e0..f1f83a6b7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2993,6 +2993,16 @@ array dequantize( int group_size /* = 64 */, int bits /* = 4 */, StreamOrDevice s /* = {} */) { + if (bits <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for bits: " << bits; + throw std::invalid_argument(msg.str()); + } + if (group_size <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for group_size: " << group_size; + throw std::invalid_argument(msg.str()); + } if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) { throw std::invalid_argument("[dequantize] Only matrices supported for now"); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index d622dcdf1..10eeac27a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -78,6 +78,11 @@ void init_ops(py::module_& m) { Flatten an array. + The axes flattened will be between ``start_axis`` and ``end_axis``, + inclusive. Negative axes are supported. After converting negative axis to + positive, axes outside the valid range will be clamped to a valid value, + ``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``. + Args: a (array): Input array. start_axis (int, optional): The first dimension to flatten. Defaults to ``0``. @@ -87,6 +92,14 @@ void init_ops(py::module_& m) { Returns: array: The flattened array. + + Example: + >>> a = mx.array([[1, 2], [3, 4]]) + >>> mx.flatten(a) + array([1, 2, 3, 4], dtype=int32) + >>> + >>> mx.flatten(a, start_axis=0, end_axis=-1) + array([1, 2, 3, 4], dtype=int32) )pbdoc"); m.def( "squeeze", @@ -801,7 +814,7 @@ void init_ops(py::module_& m) { Element-wise error function. .. math:: - \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^t e^{-t^2} \, dx + \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt Args: a (array): Input array.