Some fixes to typing (#1371)

* some fixes to typing

* fix module reference

* comment
This commit is contained in:
Awni Hannun
2024-08-28 11:16:19 -07:00
committed by GitHub
parent bd47e1f066
commit 291cf40aca
15 changed files with 152 additions and 145 deletions

View File

@@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix or vector norm.
@@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
@@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
R"pbdoc(
The QR factorization of the input matrix.
@@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix.