diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index fb91d9044..8ec82d177 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -23,6 +23,11 @@ if(MSVC) target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804) endif() +if(WIN32) + # Export symbols by default to behave like macOS/linux. + set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE) +endif() + if(MLX_BUILD_CPU) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) else() diff --git a/mlx/utils.cpp b/mlx/utils.cpp index daa90fea6..6d05ad5f8 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -56,7 +56,10 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { os << val; } -PrintFormatter global_formatter; +PrintFormatter& get_global_formatter() { + static PrintFormatter formatter; + return formatter; +} Dtype result_type(const std::vector& arrays) { Dtype t = bool_; @@ -171,7 +174,7 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { i = n - num_print - 1; index += s * (n - 2 * num_print - 1); } else if (is_last) { - global_formatter.print(os, a.data()[index]); + get_global_formatter().print(os, a.data()[index]); } else { print_subarray(os, a, index, dim + 1); } @@ -187,7 +190,7 @@ void print_array(std::ostream& os, const array& a) { os << "array("; if (a.ndim() == 0) { auto data = a.data(); - global_formatter.print(os, data[0]); + get_global_formatter().print(os, data[0]); } else { print_subarray(os, a, 0, 0); } diff --git a/mlx/utils.h b/mlx/utils.h index 108fdf203..04f59feaa 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -51,7 +51,7 @@ struct PrintFormatter { bool capitalize_bool{false}; }; -extern PrintFormatter global_formatter; +PrintFormatter& get_global_formatter(); /** The type from promoting the arrays' types with one another. */ inline Dtype result_type(const array& a, const array& b) { diff --git a/python/src/array.cpp b/python/src/array.cpp index e518f2765..017fb6e91 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -85,7 +85,7 @@ class ArrayPythonIterator { void init_array(nb::module_& m) { // Set Python print formatting options - mlx::core::global_formatter.capitalize_bool = true; + get_global_formatter().capitalize_bool = true; // Types nb::class_(