diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 289bd6053..46cc3df33 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -124,7 +124,7 @@ std::vector Primitive::vjp( const std::vector&, const std::vector&) { std::ostringstream msg; - msg << "[Primitive::vip] Not implemented for "; + msg << "[Primitive::vjp] Not implemented for "; print(msg); msg << "."; throw std::invalid_argument(msg.str()); diff --git a/mlx/random.h b/mlx/random.h index 1b97413ee..fb1a76bc0 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -7,6 +7,7 @@ #include "mlx/array.h" #include "mlx/stream.h" +#include "mlx/utils.h" namespace mlx::core::random {