mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 07:01:13 +08:00
Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten * erf doc * check values for dequantize * format
This commit is contained in:
parent
bf17ab5002
commit
07f35c9d8a
10
mlx/ops.cpp
10
mlx/ops.cpp
@ -2993,6 +2993,16 @@ array dequantize(
|
|||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (bits <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dequantize] Invalid value for bits: " << bits;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
if (group_size <= 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[dequantize] Invalid value for group_size: " << group_size;
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
|
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
|
||||||
throw std::invalid_argument("[dequantize] Only matrices supported for now");
|
throw std::invalid_argument("[dequantize] Only matrices supported for now");
|
||||||
}
|
}
|
||||||
|
@ -78,6 +78,11 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Flatten an array.
|
Flatten an array.
|
||||||
|
|
||||||
|
The axes flattened will be between ``start_axis`` and ``end_axis``,
|
||||||
|
inclusive. Negative axes are supported. After converting negative axis to
|
||||||
|
positive, axes outside the valid range will be clamped to a valid value,
|
||||||
|
``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
|
start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
|
||||||
@ -87,6 +92,14 @@ void init_ops(py::module_& m) {
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The flattened array.
|
array: The flattened array.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> a = mx.array([[1, 2], [3, 4]])
|
||||||
|
>>> mx.flatten(a)
|
||||||
|
array([1, 2, 3, 4], dtype=int32)
|
||||||
|
>>>
|
||||||
|
>>> mx.flatten(a, start_axis=0, end_axis=-1)
|
||||||
|
array([1, 2, 3, 4], dtype=int32)
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"squeeze",
|
"squeeze",
|
||||||
@ -801,7 +814,7 @@ void init_ops(py::module_& m) {
|
|||||||
Element-wise error function.
|
Element-wise error function.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^t e^{-t^2} \, dx
|
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
|
Loading…
Reference in New Issue
Block a user