diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 1e1ef8bdb..ef3c6dd2e 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2,14 +2,20 @@ import gc import operator +import os import pickle -import resource +import platform import sys import unittest import weakref from copy import copy, deepcopy from itertools import permutations +if platform.system() == "Windows": + import psutil +else: + import resource + import mlx.core as mx import mlx_tests import numpy as np @@ -1932,7 +1938,11 @@ class TestArray(mlx_tests.MLXTestCase): def test_siblings_without_eval(self): def get_mem(): - return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Windows": + process = psutil.Process(os.getpid()) + return process.memory_info().peak_wset + else: + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss key = mx.array([1, 2])