fix gumbel (#1495)

This commit is contained in:
Awni Hannun 2024-10-17 13:52:39 -07:00 committed by GitHub
parent 343aa46b78
commit 9dd72cd421
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -352,10 +352,10 @@ void init_random(nb::module_& parent_module) {
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a.none() = float32, "dtype"_a.none() = float32,
"stream"_a = nb::none(),
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig( nb::sig(
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"), "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Union[None, Stream, Device] = None, stream: Optional[array] = None) -> array"),
R"pbdoc( R"pbdoc(
Sample from the standard Gumbel distribution. Sample from the standard Gumbel distribution.
@ -364,11 +364,14 @@ void init_random(nb::module_& parent_module) {
Args: Args:
shape (list(int)): The shape of the output. shape (list(int)): The shape of the output.
dtype (Dtype, optional): The data type of the output.
Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``. key (array, optional): A PRNG key. Default: ``None``.
Returns: Returns:
array: The :class:`array` with shape ``shape`` and array:
distributed according to the Gumbel distribution The :class:`array` with shape ``shape`` and distributed according
to the Gumbel distribution.
)pbdoc"); )pbdoc");
m.def( m.def(
"categorical", "categorical",