Fix a few issues: docs for flatten, erf, dequantize validation (#560)

* doc flatten

* erf doc

* check values for dequantize

* format
This commit is contained in:
Awni Hannun 2024-01-26 15:16:46 -08:00 committed by GitHub
parent bf17ab5002
commit 07f35c9d8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 1 deletions

View File

@ -2993,6 +2993,16 @@ array dequantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
throw std::invalid_argument("[dequantize] Only matrices supported for now");
}

View File

@ -78,6 +78,11 @@ void init_ops(py::module_& m) {
Flatten an array.
The axes flattened will be between ``start_axis`` and ``end_axis``,
inclusive. Negative axes are supported. After converting negative axis to
positive, axes outside the valid range will be clamped to a valid value,
``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``.
Args:
a (array): Input array.
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
@ -87,6 +92,14 @@ void init_ops(py::module_& m) {
Returns:
array: The flattened array.
Example:
>>> a = mx.array([[1, 2], [3, 4]])
>>> mx.flatten(a)
array([1, 2, 3, 4], dtype=int32)
>>>
>>> mx.flatten(a, start_axis=0, end_axis=-1)
array([1, 2, 3, 4], dtype=int32)
)pbdoc");
m.def(
"squeeze",
@ -801,7 +814,7 @@ void init_ops(py::module_& m) {
Element-wise error function.
.. math::
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^t e^{-t^2} \, dx
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt
Args:
a (array): Input array.