diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 4599772e..cdda7abd 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -90,7 +90,8 @@ class CompletionsDatasetCollection: self.collection = data def __getitem__(self, idx: int): - item = next(self.collection) + iteration = iter(self.collection) + item = next(iteration) curr_idx = idx @@ -100,7 +101,7 @@ class CompletionsDatasetCollection: return item[curr_idx] else: curr_idx -= len(item) - item = next(self.collection) + item = next(iteration) except StopIteration: raise IndexError(idx)