fix gumbel (#1495)

This commit is contained in:
Awni Hannun 2024-10-17 13:52:39 -07:00 committed by GitHub
parent 343aa46b78
commit 9dd72cd421
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -352,10 +352,10 @@ void init_random(nb::module_& parent_module) {
},
"shape"_a = std::vector<int>{},
"dtype"_a.none() = float32,
"stream"_a = nb::none(),
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Union[None, Stream, Device] = None, stream: Optional[array] = None) -> array"),
R"pbdoc(
Sample from the standard Gumbel distribution.
@ -364,11 +364,14 @@ void init_random(nb::module_& parent_module) {
Args:
shape (list(int)): The shape of the output.
dtype (Dtype, optional): The data type of the output.
Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The :class:`array` with shape ``shape`` and
distributed according to the Gumbel distribution
array:
The :class:`array` with shape ``shape`` and distributed according
to the Gumbel distribution.
)pbdoc");
m.def(
"categorical",