post nanobind docs fixes and some updates (#889)

* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
This commit is contained in:
Awni Hannun
2024-03-24 15:03:27 -07:00
committed by GitHub
parent be98f4ab6b
commit 1e16331d9c
16 changed files with 185 additions and 118 deletions

View File

@@ -92,6 +92,8 @@ void init_random(nb::module_& parent_module) {
"key"_a,
"num"_a = 2,
"stream"_a = nb::none(),
nb::sig(
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
R"pbdoc(
Split a PRNG key into sub keys.
@@ -125,6 +127,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate uniformly distributed random numbers.
@@ -159,6 +163,8 @@ void init_random(nb::module_& parent_module) {
"scale"_a = 1.0,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.
@@ -190,6 +196,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate random integers from the given interval.
@@ -225,6 +233,8 @@ void init_random(nb::module_& parent_module) {
"shape"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate Bernoulli random values.
@@ -266,6 +276,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate values from a truncated normal distribution.
@@ -298,6 +310,8 @@ void init_random(nb::module_& parent_module) {
"dtype"_a.none() = float32,
"stream"_a = nb::none(),
"key"_a = nb::none(),
nb::sig(
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sample from the standard Gumbel distribution.
@@ -338,6 +352,8 @@ void init_random(nb::module_& parent_module) {
"num_samples"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Sample from a categorical distribution.