mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Some fixes to typing (#1371)
* some fixes to typing * fix module reference * comment
This commit is contained in:
@@ -93,7 +93,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"num"_a = 2,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
|
||||
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
@@ -321,7 +321,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user