diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index c75171e5..1a86fb9f 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -87,7 +87,8 @@ class CompletionsDatasetCollection: self.collection = data def __getitem__(self, idx: int): - item = next(self.collection) + iteration = iter(self.collection) + item = next(iteration) curr_idx = idx @@ -97,7 +98,7 @@ class CompletionsDatasetCollection: return item[curr_idx] else: curr_idx -= len(item) - item = next(self.collection) + item = next(iteration) except StopIteration: raise IndexError(idx)