diff --git a/mlx/random.h b/mlx/random.h index 360bdbdb1..ab75eb488 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include "mlx/array.h" @@ -18,12 +19,18 @@ class KeySequence { // static default static KeySequence& default_() { - static KeySequence ks(0); + static KeySequence ks(get_current_time_seed()); return ks; } private: array key_; + static uint64_t get_current_time_seed() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + } }; /** Get a PRNG key from a seed. */