mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates * one more doc nit * fix for stubs and latex
This commit is contained in:
@@ -92,6 +92,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"key"_a,
|
||||
"num"_a = 2,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
@@ -125,6 +127,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def uniform(low: Union[scalar, array] = 0, high: Union[scalar, array] = 1, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate uniformly distributed random numbers.
|
||||
|
||||
@@ -159,6 +163,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"scale"_a = 1.0,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
@@ -190,6 +196,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"dtype"_a.none() = int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def randint(low: Union[scalar, array], high: Union[scalar, array], shape: Sequence[int] = [], dtype: Optional[Dtype] = int32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate random integers from the given interval.
|
||||
|
||||
@@ -225,6 +233,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"shape"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bernoulli(p: Union[scalar, array] = 0.5, shape: Optional[Sequence[int]] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate Bernoulli random values.
|
||||
|
||||
@@ -266,6 +276,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"dtype"_a.none() = float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
@@ -298,6 +310,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"dtype"_a.none() = float32,
|
||||
"stream"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Sample from the standard Gumbel distribution.
|
||||
|
||||
@@ -338,6 +352,8 @@ void init_random(nb::module_& parent_module) {
|
||||
"num_samples"_a = nb::none(),
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def categorical(logits: array, axis: int = -1, shape: Optional[Sequence[int]] = None, num_samples: Optional[int] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Sample from a categorical distribution.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user