Skip to content

Datasets

Dataset utilities for loading, tokenizing, and packing sequences for causal language model training.

Related documentation:

Fast HuggingFace Loader

forgather.ml.datasets.fast_hf_loader.FastDatasetLoaderSimple

Fast HuggingFace dataset loader backed by an Arrow file index.

On the first call for a given dataset/split combination the loader downloads (or locates) the dataset via the HuggingFace datasets library, records the paths and per-file example counts of the underlying Arrow cache files in a compact JSON index, and returns a ComposableIterableDataset wrapping an ArrowBackend. All subsequent calls for the same configuration load in milliseconds by reading the index directly.

Both HuggingFace Hub datasets and locally saved datasets (produced by Dataset.save_to_disk()) are supported.

Parameters:

Name Type Description Default
index_dir str

Directory in which the JSON index files are stored. Defaults to ~/.cache/fast_hf_indexes_simple.

None

Examples:

>>> loader = FastDatasetLoaderSimple()
>>> ds = loader.load_iterable("allenai/c4", name="en", split="train")
>>> ds = ds.shuffle(seed=42).shard(num_shards=4, index=0)
>>> for example in ds:
...     pass
Source code in src/forgather/ml/datasets/fast_hf_loader.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
class FastDatasetLoaderSimple:
    """
    Fast HuggingFace dataset loader backed by an Arrow file index.

    On the first call for a given dataset/split combination the loader
    downloads (or locates) the dataset via the HuggingFace ``datasets``
    library, records the paths and per-file example counts of the
    underlying Arrow cache files in a compact JSON index, and returns
    a `ComposableIterableDataset` wrapping an `ArrowBackend`. All
    subsequent calls for the same configuration load in milliseconds
    by reading the index directly.

    Both HuggingFace Hub datasets and locally saved datasets (produced
    by ``Dataset.save_to_disk()``) are supported.

    Parameters
    ----------
    index_dir : str, optional
        Directory in which the JSON index files are stored. Defaults to
        ``~/.cache/fast_hf_indexes_simple``.

    Examples
    --------
    >>> loader = FastDatasetLoaderSimple()
    >>> ds = loader.load_iterable("allenai/c4", name="en", split="train")
    >>> ds = ds.shuffle(seed=42).shard(num_shards=4, index=0)
    >>> for example in ds:
    ...     pass
    """

    def __init__(self, index_dir: Optional[str] = None):
        if index_dir is None:
            index_dir = os.path.expanduser("~/.cache/fast_hf_indexes_simple")

        self.index_dir = Path(index_dir)
        self.index_dir.mkdir(parents=True, exist_ok=True)

    def _get_config_hash(
        self,
        path: str,
        name: Optional[str] = None,
        split: Optional[str] = None,
        data_files: Optional[Union[str, list]] = None,
        revision: Optional[str] = None,
        **kwargs,
    ) -> str:
        config = {
            "path": path,
            "name": name,
            "split": split,
            "data_files": data_files,
            "revision": revision,
        }
        config_str = json.dumps(config, sort_keys=True)
        return hashlib.sha256(config_str.encode()).hexdigest()[:16]

    def _get_index_file(self, config_hash: str) -> Path:
        return self.index_dir / f"{config_hash}.json"

    def _get_arrow_files(self, dataset_obj: Dataset) -> Optional[list]:
        """Get Arrow file paths from dataset."""
        if hasattr(dataset_obj, "cache_files") and dataset_obj.cache_files:
            return [cf["filename"] for cf in dataset_obj.cache_files]
        if hasattr(dataset_obj, "_data_files") and dataset_obj._data_files:
            return [df["filename"] for df in dataset_obj._data_files]
        return None

    def _get_file_lengths_from_metadata(
        self, arrow_files: List[str], split: str
    ) -> Optional[List[int]]:
        """
        Try to extract file lengths from HuggingFace's ``dataset_info.json``.

        Avoids opening each Arrow file individually to read metadata. For
        datasets with thousands of files this is significantly faster.
        Returns ``None`` on any mismatch / missing data — caller falls
        back to opening files individually.
        """
        if not arrow_files:
            return None

        try:
            cache_dir = Path(arrow_files[0]).parent
            dataset_info_path = cache_dir / "dataset_info.json"
            if not dataset_info_path.exists():
                return None

            with open(dataset_info_path, "r") as f:
                dataset_info = json.load(f)

            splits = dataset_info.get("splits", {})
            if split not in splits:
                return None

            shard_lengths = splits[split].get("shard_lengths", [])

            if len(shard_lengths) != len(arrow_files):
                logger.warning(
                    f"Shard count mismatch: dataset_info.json has "
                    f"{len(shard_lengths)} shards, but found "
                    f"{len(arrow_files)} Arrow files. Falling back to "
                    f"file-by-file indexing."
                )
                return None

            if not shard_lengths:
                return None

            logger.info(
                f"Loaded file lengths from dataset_info.json: "
                f"{len(shard_lengths)} files, "
                f"{sum(shard_lengths):,} total examples"
            )
            return shard_lengths

        except Exception as e:
            logger.debug(
                f"Could not load file lengths from dataset_info.json: {e}. "
                f"Falling back to file-by-file indexing."
            )
            return None

    def _is_saved_dataset_path(self, path: str) -> bool:
        """
        Check if path is a local directory containing a saved dataset.

        Multi-split form has ``dataset_dict.json`` at the root and
        per-split subdirectories. Single-split form has ``state.json``
        at the root.
        """
        if not path:
            return False

        dataset_path = Path(path)
        if not dataset_path.is_dir():
            return False

        if (dataset_path / "dataset_dict.json").exists():
            return True
        if (dataset_path / "state.json").exists():
            return True

        return False

    def _load_saved_dataset(
        self,
        path: str,
        split: str,
        force_reindex: bool = False,
        length_estimate: str = "dynamic",
        reset_length_on_iter: bool = False,
    ) -> Optional[ComposableIterableDataset]:
        """Load a saved dataset directly from disk."""
        dataset_path = Path(path)

        # Determine split directory.
        dataset_dict_path = dataset_path / "dataset_dict.json"
        if dataset_dict_path.exists():
            with open(dataset_dict_path, "r") as f:
                dataset_dict = json.load(f)
            available_splits = dataset_dict.get("splits", [])
            if split not in available_splits:
                logger.warning(
                    f"Split '{split}' not found in saved dataset. "
                    f"Available splits: {available_splits}"
                )
                return None
            split_dir = dataset_path / split
        else:
            split_dir = dataset_path

        # Read state.json to get data files.
        state_path = split_dir / "state.json"
        if not state_path.exists():
            logger.warning(f"state.json not found in {split_dir}")
            return None

        with open(state_path, "r") as f:
            state = json.load(f)

        data_files = state.get("_data_files", [])
        if not data_files:
            logger.warning(f"No data files listed in {state_path}")
            return None

        # Build full paths to Arrow files.
        arrow_files = [
            str(split_dir / df["filename"])
            for df in data_files
            if df.get("filename", "").endswith(".arrow")
        ]
        if not arrow_files:
            logger.warning(f"No Arrow files found in {split_dir}")
            return None

        missing = [f for f in arrow_files if not Path(f).exists()]
        if missing:
            logger.warning(f"Missing Arrow files: {missing[:5]}...")
            return None

        num_files = len(arrow_files)
        logger.info(f"Found saved dataset with {num_files} Arrow file(s)")

        # Check for cached index.
        config_hash = self._get_config_hash(path, split=split)
        if not force_reindex:
            index_data = self._load_index(config_hash)
            if index_data is not None:
                cached_files = index_data.get("arrow_files", [])
                if cached_files == arrow_files:
                    logger.info("Loading from cached index")
                    file_lengths = index_data.get("file_lengths")
                    return _make_dataset(
                        arrow_files,
                        file_lengths,
                        length_estimate=length_estimate,
                        reset_length_on_iter=reset_length_on_iter,
                    )

        # Get file lengths — try dataset_info.json first.
        file_lengths = self._get_file_lengths_from_metadata(arrow_files, split)

        if file_lengths is None:
            logger.info("Computing per-file example counts...")
            file_lengths = []

            use_progress = HAS_TQDM and sys.stderr.isatty()
            iterator = (
                tqdm(arrow_files, desc="Indexing files", unit="file")
                if use_progress
                else arrow_files
            )

            for arrow_file in iterator:
                ds_file = Dataset.from_file(arrow_file)
                file_lengths.append(len(ds_file))

        total_examples = sum(file_lengths)
        logger.info(f"Total examples: {total_examples:,}")

        metadata = {
            "dataset_path": path,
            "split": split,
            "source": "saved_dataset",
            "num_arrow_files": num_files,
            "total_examples": total_examples,
        }
        self._save_index(config_hash, arrow_files, file_lengths, metadata)

        return _make_dataset(
            arrow_files,
            file_lengths,
            length_estimate=length_estimate,
            reset_length_on_iter=reset_length_on_iter,
        )

    def _save_index(
        self,
        config_hash: str,
        arrow_files: list,
        file_lengths: list,
        metadata: Dict[str, Any],
    ):
        index_data = {
            "version": METADATA_VERSION,
            "arrow_files": arrow_files,
            "file_lengths": file_lengths,
            "metadata": metadata,
            "indexed_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        }

        index_file = self._get_index_file(config_hash)
        with open(index_file, "w") as f:
            json.dump(index_data, f, indent=2)

    def _load_index(self, config_hash: str) -> Optional[Dict[str, Any]]:
        index_file = self._get_index_file(config_hash)
        if not index_file.exists():
            return None

        with open(index_file, "r") as f:
            index_data = json.load(f)

        # Force reindex on version mismatch.
        stored_version = index_data.get("version", 1)
        if stored_version != METADATA_VERSION:
            logger.info(
                f"Index version mismatch (stored: v{stored_version}, "
                f"current: v{METADATA_VERSION}). Forcing reindex..."
            )
            return None

        return index_data

    def load_iterable(
        self,
        path: str,
        name: Optional[str] = None,
        split: Optional[str] = None,
        data_files: Optional[Union[str, list]] = None,
        revision: Optional[str] = None,
        force_reindex: bool = False,
        num_proc: Optional[int] = None,
        length_estimate: str = "dynamic",
        reset_length_on_iter: bool = False,
        **load_dataset_kwargs,
    ) -> ComposableIterableDataset:
        """
        Load a dataset as a `ComposableIterableDataset` over an
        `ArrowBackend`.

        Parameters
        ----------
        path : str
            HuggingFace Hub identifier or a local saved-dataset path.
        name : str, optional
            Dataset configuration name.
        split : str, optional
            Split, with optional slice notation (e.g. ``"train[10000:]"``).
            For a saved-dataset path, ``"train"`` is used when no split
            is given.
        data_files, revision, num_proc : optional
            Forwarded to ``datasets.load_dataset`` on the slow path.
        force_reindex : bool, optional
            Rebuild the Arrow file index even when a valid cached index
            already exists.
        length_estimate : {"dynamic", "static", "exact"}, optional
            Length-estimation mode for the wrapper. Default ``"dynamic"``.
        reset_length_on_iter : bool, optional
            Reset wrapper length-estimation counters at the start of each
            new iteration. Default ``False``.
        **load_dataset_kwargs
            Forwarded to ``datasets.load_dataset`` on the slow path.

        Returns
        -------
        ComposableIterableDataset
            Wrapper around an `ArrowBackend` ready for shuffling,
            sharding, mapping, and checkpointing.

        Notes
        -----
        If no Arrow files can be located for the loaded dataset, the
        result of ``ds.to_iterable_dataset(num_shards=1)`` is returned
        instead; any slice in ``split`` is not applied on that path and
        a warning is logged.
        """
        # Parse the split once: the saved-dataset branch and the Hub
        # branch use the same base split and slice bounds.
        base_split, slice_start, slice_end = (
            _parse_split_notation(split) if split else (split, None, None)
        )
        has_slice = slice_start is not None or slice_end is not None

        # Saved-dataset path?
        if self._is_saved_dataset_path(path):
            logger.info(f"Detected saved dataset at: {path}")
            result = self._load_saved_dataset(
                path=path,
                split=base_split or "train",
                force_reindex=force_reindex,
                length_estimate=length_estimate,
                reset_length_on_iter=reset_length_on_iter,
            )
            if result is not None:
                if has_slice:
                    result = result.slice(slice_start, slice_end)
                return result
            # The code below goes through load_dataset, so say so.
            logger.warning(
                "Failed to load saved dataset, falling back to load_dataset"
            )

        # Hub-style path: try the cached index first.
        config_hash = self._get_config_hash(
            path, name, base_split, data_files, revision
        )
        index_data = self._load_index(config_hash) if not force_reindex else None

        if index_data is not None:
            arrow_files = index_data["arrow_files"]
            file_lengths = index_data.get("file_lengths")

            # The index is only usable while every recorded file exists.
            if all(Path(f).exists() for f in arrow_files):
                start_time = time.time()
                logger.debug(f"Dataset: {path}" + (f"/{name}" if name else ""))
                if split:
                    logger.debug(f"Split: {split}")

                ds = _make_dataset(
                    arrow_files,
                    file_lengths,
                    length_estimate=length_estimate,
                    reset_length_on_iter=reset_length_on_iter,
                    slice_start=slice_start,
                    slice_end=slice_end,
                )

                elapsed = time.time() - start_time
                logger.debug(
                    f"Loaded as IterableDataset in {elapsed:.3f}s "
                    f"Arrow files: {len(arrow_files)} (natural shards)"
                )
                return ds

            logger.warning("Arrow files missing. Re-indexing...")

        # Slow path: initial load.
        logger.info(
            f"{'Re-indexing' if index_data else 'First-time indexing'} dataset..."
        )
        logger.info(f"Dataset: {path}" + (f"/{name}" if name else ""))
        logger.info("This will be slow, but only happens once...")

        start_time = time.time()
        ds = load_dataset(
            path,
            name=name,
            split=base_split,
            data_files=data_files,
            revision=revision,
            num_proc=num_proc,
            **load_dataset_kwargs,
        )
        load_time = time.time() - start_time
        logger.info(f"Dataset loaded in {load_time:.1f}s")

        arrow_files = self._get_arrow_files(ds)

        if not arrow_files:
            logger.warning("Could not find Arrow files")
            if has_slice:
                # Make the dropped slice visible instead of silently
                # handing back more data than was requested.
                logger.warning(
                    f"Slice notation in split '{split}' is not applied to "
                    f"the fallback dataset; returning the full "
                    f"'{base_split}' split."
                )
            # Fallback: use regular to_iterable_dataset.
            return ds.to_iterable_dataset(num_shards=1)

        num_files = len(arrow_files)
        logger.info(f"Found {num_files} Arrow file(s) in HF cache")

        file_lengths = self._get_file_lengths_from_metadata(arrow_files, base_split)

        if file_lengths is None:
            logger.info("Computing per-file example counts...")
            file_lengths = []

            # Show a progress bar only on an interactive terminal.
            use_progress = HAS_TQDM and sys.stderr.isatty()
            iterator = (
                tqdm(arrow_files, desc="Indexing files", unit="file")
                if use_progress
                else arrow_files
            )

            for arrow_file in iterator:
                ds_file = Dataset.from_file(arrow_file)
                file_lengths.append(len(ds_file))

        total_examples = sum(file_lengths)
        logger.info(f"Total examples: {total_examples:,}")

        metadata = {
            "dataset_path": path,
            "dataset_name": name,
            "split": base_split,
            "load_time": load_time,
            "num_arrow_files": num_files,
            "total_examples": total_examples,
        }
        self._save_index(config_hash, arrow_files, file_lengths, metadata)

        total_size = sum(Path(f).stat().st_size for f in arrow_files)
        size_gb = total_size / (1024**3)
        logger.info(
            f"Index saved: {num_files} Arrow files = {num_files} "
            f"natural shards, Data size: {size_gb:.2f} GB"
        )

        return _make_dataset(
            arrow_files,
            file_lengths,
            length_estimate=length_estimate,
            reset_length_on_iter=reset_length_on_iter,
            slice_start=slice_start,
            slice_end=slice_end,
        )

load_iterable(path, name=None, split=None, data_files=None, revision=None, force_reindex=False, num_proc=None, length_estimate='dynamic', reset_length_on_iter=False, **load_dataset_kwargs)

Load a dataset as a ComposableIterableDataset over an ArrowBackend.

Parameters:

Name Type Description Default
path str

HuggingFace Hub identifier or a local saved-dataset path.

required
name str

Dataset configuration name.

None
split str

Split, with optional slice notation (e.g. "train[10000:]").

None
data_files optional

Forwarded to datasets.load_dataset on the slow path.

None
revision optional

Forwarded to datasets.load_dataset on the slow path.

None
num_proc optional

Forwarded to datasets.load_dataset on the slow path.

None
force_reindex bool

Rebuild the Arrow file index even when a valid cached index already exists.

False
length_estimate (dynamic, static, exact)

Length-estimation mode for the wrapper. Default "dynamic".

"dynamic"
reset_length_on_iter bool

Reset wrapper length-estimation counters at the start of each new iteration. Default False.

False
**load_dataset_kwargs

Forwarded to datasets.load_dataset on the slow path.

{}

Returns:

Type Description
ComposableIterableDataset

Wrapper around an ArrowBackend ready for shuffling, sharding, mapping, and checkpointing.

Source code in src/forgather/ml/datasets/fast_hf_loader.py
def load_iterable(
    self,
    path: str,
    name: Optional[str] = None,
    split: Optional[str] = None,
    data_files: Optional[Union[str, list]] = None,
    revision: Optional[str] = None,
    force_reindex: bool = False,
    num_proc: Optional[int] = None,
    length_estimate: str = "dynamic",
    reset_length_on_iter: bool = False,
    **load_dataset_kwargs,
) -> ComposableIterableDataset:
    """
    Load a dataset as a `ComposableIterableDataset` over an
    `ArrowBackend`.

    Parameters
    ----------
    path : str
        HuggingFace Hub identifier or a local saved-dataset path.
    name : str, optional
        Dataset configuration name.
    split : str, optional
        Split, with optional slice notation (e.g. ``"train[10000:]"``).
        For a saved-dataset path, ``"train"`` is used when no split
        is given.
    data_files, revision, num_proc : optional
        Forwarded to ``datasets.load_dataset`` on the slow path.
    force_reindex : bool, optional
        Rebuild the Arrow file index even when a valid cached index
        already exists.
    length_estimate : {"dynamic", "static", "exact"}, optional
        Length-estimation mode for the wrapper. Default ``"dynamic"``.
    reset_length_on_iter : bool, optional
        Reset wrapper length-estimation counters at the start of each
        new iteration. Default ``False``.
    **load_dataset_kwargs
        Forwarded to ``datasets.load_dataset`` on the slow path.

    Returns
    -------
    ComposableIterableDataset
        Wrapper around an `ArrowBackend` ready for shuffling,
        sharding, mapping, and checkpointing.

    Notes
    -----
    If no Arrow files can be located for the loaded dataset, the
    result of ``ds.to_iterable_dataset(num_shards=1)`` is returned
    instead; any slice in ``split`` is not applied on that path and
    a warning is logged.
    """
    # Parse the split once: the saved-dataset branch and the Hub
    # branch use the same base split and slice bounds.
    base_split, slice_start, slice_end = (
        _parse_split_notation(split) if split else (split, None, None)
    )
    has_slice = slice_start is not None or slice_end is not None

    # Saved-dataset path?
    if self._is_saved_dataset_path(path):
        logger.info(f"Detected saved dataset at: {path}")
        result = self._load_saved_dataset(
            path=path,
            split=base_split or "train",
            force_reindex=force_reindex,
            length_estimate=length_estimate,
            reset_length_on_iter=reset_length_on_iter,
        )
        if result is not None:
            if has_slice:
                result = result.slice(slice_start, slice_end)
            return result
        # The code below goes through load_dataset, so say so.
        logger.warning(
            "Failed to load saved dataset, falling back to load_dataset"
        )

    # Hub-style path: try the cached index first.
    config_hash = self._get_config_hash(
        path, name, base_split, data_files, revision
    )
    index_data = self._load_index(config_hash) if not force_reindex else None

    if index_data is not None:
        arrow_files = index_data["arrow_files"]
        file_lengths = index_data.get("file_lengths")

        # The index is only usable while every recorded file exists.
        if all(Path(f).exists() for f in arrow_files):
            start_time = time.time()
            logger.debug(f"Dataset: {path}" + (f"/{name}" if name else ""))
            if split:
                logger.debug(f"Split: {split}")

            ds = _make_dataset(
                arrow_files,
                file_lengths,
                length_estimate=length_estimate,
                reset_length_on_iter=reset_length_on_iter,
                slice_start=slice_start,
                slice_end=slice_end,
            )

            elapsed = time.time() - start_time
            logger.debug(
                f"Loaded as IterableDataset in {elapsed:.3f}s "
                f"Arrow files: {len(arrow_files)} (natural shards)"
            )
            return ds

        logger.warning("Arrow files missing. Re-indexing...")

    # Slow path: initial load.
    logger.info(
        f"{'Re-indexing' if index_data else 'First-time indexing'} dataset..."
    )
    logger.info(f"Dataset: {path}" + (f"/{name}" if name else ""))
    logger.info("This will be slow, but only happens once...")

    start_time = time.time()
    ds = load_dataset(
        path,
        name=name,
        split=base_split,
        data_files=data_files,
        revision=revision,
        num_proc=num_proc,
        **load_dataset_kwargs,
    )
    load_time = time.time() - start_time
    logger.info(f"Dataset loaded in {load_time:.1f}s")

    arrow_files = self._get_arrow_files(ds)

    if not arrow_files:
        logger.warning("Could not find Arrow files")
        if has_slice:
            # Make the dropped slice visible instead of silently
            # handing back more data than was requested.
            logger.warning(
                f"Slice notation in split '{split}' is not applied to "
                f"the fallback dataset; returning the full "
                f"'{base_split}' split."
            )
        # Fallback: use regular to_iterable_dataset.
        return ds.to_iterable_dataset(num_shards=1)

    num_files = len(arrow_files)
    logger.info(f"Found {num_files} Arrow file(s) in HF cache")

    file_lengths = self._get_file_lengths_from_metadata(arrow_files, base_split)

    if file_lengths is None:
        logger.info("Computing per-file example counts...")
        file_lengths = []

        # Show a progress bar only on an interactive terminal.
        use_progress = HAS_TQDM and sys.stderr.isatty()
        iterator = (
            tqdm(arrow_files, desc="Indexing files", unit="file")
            if use_progress
            else arrow_files
        )

        for arrow_file in iterator:
            ds_file = Dataset.from_file(arrow_file)
            file_lengths.append(len(ds_file))

    total_examples = sum(file_lengths)
    logger.info(f"Total examples: {total_examples:,}")

    metadata = {
        "dataset_path": path,
        "dataset_name": name,
        "split": base_split,
        "load_time": load_time,
        "num_arrow_files": num_files,
        "total_examples": total_examples,
    }
    self._save_index(config_hash, arrow_files, file_lengths, metadata)

    total_size = sum(Path(f).stat().st_size for f in arrow_files)
    size_gb = total_size / (1024**3)
    logger.info(
        f"Index saved: {num_files} Arrow files = {num_files} "
        f"natural shards, Data size: {size_gb:.2f} GB"
    )

    return _make_dataset(
        arrow_files,
        file_lengths,
        length_estimate=length_estimate,
        reset_length_on_iter=reset_length_on_iter,
        slice_start=slice_start,
        slice_end=slice_end,
    )

forgather.ml.datasets.fast_hf_loader.fast_load_iterable_dataset(path, name=None, split=None, data_files=None, revision=None, force_reindex=False, num_proc=None, index_dir=None, length_estimate='dynamic', reset_length_on_iter=False, **load_dataset_kwargs)

Load a HuggingFace dataset as a fast iterable with sharding and checkpoint support.

Routing
  • If the FORGATHER_DATASET_SERVER environment variable is set to a URL (e.g. http://host:8765), the load is routed transparently through the dataset server and a RemoteBackend-wrapped dataset is returned. The server must have been started with --allow-load. Local-only options (force_reindex, num_proc, index_dir, **load_dataset_kwargs) are not forwarded over the wire and take effect only on the local path.
  • Otherwise, loads locally via FastDatasetLoaderSimple. The first call for a given dataset is slow (it builds an Arrow file index); all subsequent calls are instant.

Parameters:

Name Type Description Default
path str

HuggingFace Hub identifier (e.g. "allenai/c4") or a local path to a dataset saved with Dataset.save_to_disk().

required
name str

Dataset configuration name (e.g. "en" for C4 English).

None
split str

Split to load. Supports HuggingFace slice notation such as "train[10000:]" or "validation[:500]".

None
data_files str or list of str

Specific data files to load (forwarded to load_dataset).

None
revision str

Dataset revision or commit hash (forwarded to load_dataset).

None
force_reindex bool

Rebuild the Arrow file index from scratch (local path only).

False
num_proc int

Number of processes for the initial dataset download/indexing step (local path only).

None
index_dir str

Directory where JSON index files are stored (local path only).

None
length_estimate (dynamic, static, exact)

Length-estimation mode for the wrapper.

"dynamic"
reset_length_on_iter bool

Whether to reset length-estimation counters at the start of each new iteration pass.

False
**load_dataset_kwargs

Extra keyword arguments forwarded to datasets.load_dataset on the initial (slow-path) local load. Not forwarded to the remote server.

{}

Returns:

Type Description
ComposableIterableDataset

Iterable dataset (wrapper over ArrowBackend locally or RemoteBackend when routed through the server) supporting:

  • .shuffle(seed) for backend-level + buffer-level shuffling
  • .shard(num_shards, index) for DDP data partitioning
  • .map(fn) for lazy transformations
  • .slice() / .select() for virtual splits
  • state_dict / load_state_dict for stateful checkpointing

Examples:

>>> ds = fast_load_iterable_dataset("allenai/c4", name="en", split="train")
>>> ds = ds.shuffle(seed=42)
>>> ds = ds.shard(num_shards=world_size, index=rank)
>>> ds = ds.map(tokenize)
>>> for example in ds:
...     pass
Source code in src/forgather/ml/datasets/fast_hf_loader.py
def fast_load_iterable_dataset(
    path: str,
    name: Optional[str] = None,
    split: Optional[str] = None,
    data_files: Optional[Union[str, list]] = None,
    revision: Optional[str] = None,
    force_reindex: bool = False,
    num_proc: Optional[int] = None,
    index_dir: Optional[str] = None,
    length_estimate: str = "dynamic",
    reset_length_on_iter: bool = False,
    **load_dataset_kwargs,
) -> ComposableIterableDataset:
    """
    Load a HuggingFace dataset as a fast iterable with sharding and
    checkpoint support.

    Routing
    -------
    - If the ``FORGATHER_DATASET_SERVER`` environment variable is set
      to a URL (e.g. ``http://host:8765``), the load is routed
      transparently through the dataset server and a
      `RemoteBackend`-wrapped dataset is returned. The server must
      have been started with ``--allow-load``. Local-only options
      (``force_reindex``, ``num_proc``, ``index_dir``,
      ``**load_dataset_kwargs``) are not forwarded over the wire and
      take effect only on the local path.
    - If the variable holds the auto-routing sentinel
      (case-insensitive), the load is delegated to
      `_auto_load_iterable_dataset` instead of a fixed URL.
    - Otherwise (variable unset, empty, or whitespace only), loads
      locally via `FastDatasetLoaderSimple`. The first call for a
      given dataset is slow (it builds an Arrow file index); all
      subsequent calls are instant.

    Surrounding whitespace in the environment variable is ignored.

    Parameters
    ----------
    path : str
        HuggingFace Hub identifier (e.g. ``"allenai/c4"``) **or** a local
        path to a dataset saved with ``Dataset.save_to_disk()``.
    name : str, optional
        Dataset configuration name (e.g. ``"en"`` for C4 English).
    split : str, optional
        Split to load. Supports HuggingFace slice notation such as
        ``"train[10000:]"`` or ``"validation[:500]"``.
    data_files : str or list of str, optional
        Specific data files to load (forwarded to ``load_dataset``).
    revision : str, optional
        Dataset revision or commit hash (forwarded to ``load_dataset``).
    force_reindex : bool, optional
        Rebuild the Arrow file index from scratch (local path only).
    num_proc : int, optional
        Number of processes for the initial dataset download/indexing
        step (local path only).
    index_dir : str, optional
        Directory where JSON index files are stored (local path only).
    length_estimate : {"dynamic", "static", "exact"}, optional
        Length-estimation mode for the wrapper.
    reset_length_on_iter : bool, optional
        Forwarded to the wrapper's ``reset_length_on_iter`` option.
    **load_dataset_kwargs
        Extra keyword arguments forwarded to ``datasets.load_dataset``
        on the initial (slow-path) local load. Not forwarded to the
        remote server; a warning is logged if any are supplied while
        routing through the server.

    Returns
    -------
    ComposableIterableDataset
        Iterable dataset (wrapper over `ArrowBackend` locally or
        `RemoteBackend` when routed through the server) supporting:

        - `.shuffle(seed)` for backend-level + buffer-level shuffling
        - `.shard(num_shards, index)` for DDP data partitioning
        - `.map(fn)` for lazy transformations
        - `.slice()` / `.select()` for virtual splits
        - `state_dict` / `load_state_dict` for stateful checkpointing

    Examples
    --------
    >>> ds = fast_load_iterable_dataset("allenai/c4", name="en", split="train")
    >>> ds = ds.shuffle(seed=42)
    >>> ds = ds.shard(num_shards=world_size, index=rank)
    >>> ds = ds.map(tokenize)
    >>> for example in ds:
    ...     pass
    """
    # Normalise once. Stray whitespace (e.g. a trailing newline from a shell
    # export) must not end up inside the URL handed to the remote loader,
    # and a blank value is treated the same as "not configured".
    server_url = (os.environ.get(DATASET_SERVER_ENV_VAR) or "").strip()

    if not server_url:
        return _local_load_iterable_dataset(
            path=path,
            name=name,
            split=split,
            data_files=data_files,
            revision=revision,
            force_reindex=force_reindex,
            num_proc=num_proc,
            index_dir=index_dir,
            length_estimate=length_estimate,
            reset_length_on_iter=reset_length_on_iter,
            **load_dataset_kwargs,
        )

    if load_dataset_kwargs:
        logger.warning(
            "Ignoring load_dataset_kwargs %s when routing through "
            "%s — they only apply on the local path.",
            list(load_dataset_kwargs.keys()),
            DATASET_SERVER_ENV_VAR,
        )

    # Only these arguments travel over the wire; both remote entry points
    # take the same set.
    forwarded = dict(
        path=path,
        name=name,
        split=split,
        data_files=data_files,
        revision=revision,
        length_estimate=length_estimate,
        reset_length_on_iter=reset_length_on_iter,
    )

    if server_url.lower() == DATASET_SERVER_AUTO_SENTINEL:
        # Cluster auto-routing: ask the local forgather_server for
        # a healthy dataset_server. The resolver also handles
        # re-routing on failure during long-running iteration.
        return _auto_load_iterable_dataset(**forwarded)
    return _remote_load_iterable_dataset(server_url, **forwarded)

Backend abstraction

The loader returns a ComposableIterableDataset wrapped around an ArrowBackend. The same wrapper can sit on top of an InMemoryBackend or a RemoteBackend (network proxy to a Dataset Server) without client code changes.

forgather.ml.datasets.composable_iterable_dataset.ComposableIterableDataset

Bases: IterableDataset

Backend-agnostic iterable dataset wrapper.

Wraps any IterableDatasetBackend. Composable transformations (map, slice, shard, shuffle, …) return new wrapper instances; set_epoch mutates in place (callers re-use the same wrapper instance across epochs). Backend-mutating ops (shuffle, seek) return new backend instances and the wrapper holds a reference to the latest one.

The shard mode parameter that the legacy Arrow class supported is intentionally absent: at this layer sharding is purely logical (compute a contiguous example range; restrict iteration to it). Backends that want to do physical optimizations (e.g. file-level affinity) can do so privately on their own; the wrapper does not surface that distinction.

Multi-worker DataLoader support is built in: when iterated under torch.utils.data.DataLoader(num_workers > 1) each worker takes a contiguous sub-window of the visible range. Per-worker checkpoint state is captured by state_dict and restored by load_state_dict.

Length estimation has three modes (length_estimate_mode):

  • "static" — __len__ always returns the view length (after slice/shard), ignoring map-induced cardinality changes.
  • "dynamic" (default) — progressive ratio-based estimate during the first complete pass, then locked to the exact count via _cached_exact_length once iteration runs to completion.
  • "exact" — alias for "dynamic".

Parameters:

Name Type Description Default
backend IterableDatasetBackend

Underlying storage backend.

required
length_estimate ('dynamic', 'static', 'exact')

Initial length-estimation mode. Default "dynamic".

"dynamic"
reset_length_on_iter bool

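If True, every new iteration also discards the cached exact length, so the length is re-measured on each pass. Default False (the exact length cached after a completed pass is kept). In both modes the input/output counters restart at the beginning of a pass, except when they are carried over as a ratio after shuffle() or preserved after load_state_dict.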

False
Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
class ComposableIterableDataset(TorchIterableDataset):
    """
    Backend-agnostic iterable dataset wrapper.

    Wraps any `IterableDatasetBackend`. Composable transformations
    (`map`, `slice`, `shard`, `shuffle`, …) return new wrapper
    instances; `set_epoch` mutates in place (callers re-use the same
    wrapper instance across epochs). Backend-mutating ops (`shuffle`,
    `seek`) return new backend instances and the wrapper holds a
    reference to the latest one.

    The shard `mode` parameter that the legacy Arrow class supported
    is intentionally absent: at this layer sharding is purely logical
    (compute a contiguous example range; restrict iteration to it).
    Backends that want to do physical optimizations (e.g. file-level
    affinity) can do so privately on their own; the wrapper does not
    surface that distinction.

    Multi-worker DataLoader support is built in: when iterated under
    `torch.utils.data.DataLoader(num_workers > 1)` each worker takes a
    contiguous sub-window of the visible range. Per-worker checkpoint
    state is captured by `state_dict` and restored by `load_state_dict`.

    Length estimation has three modes (`length_estimate_mode`):

    - ``"static"`` — `__len__` always returns the view length (after
      slice/shard), ignoring map-induced cardinality changes.
    - ``"dynamic"`` (default) — progressive ratio-based estimate during
      the first complete pass, then locked to the exact count via
      ``_cached_exact_length`` once iteration runs to completion.
    - ``"exact"`` — alias for ``"dynamic"``.

    Parameters
    ----------
    backend : IterableDatasetBackend
        Underlying storage backend.
    length_estimate : {"dynamic", "static", "exact"}, optional
        Initial length-estimation mode. Default ``"dynamic"``.
    reset_length_on_iter : bool, optional
        If ``True``, reset input/output counters at the start of every
        new iteration. Default ``False`` (counters accumulate across
        passes).
    """

    def __init__(
        self,
        backend: IterableDatasetBackend,
        length_estimate: str = "dynamic",
        reset_length_on_iter: bool = False,
    ):
        """
        Wrap ``backend`` with an unrestricted view: no slice, no shard,
        no shuffle and an empty map chain.

        Raises
        ------
        ValueError
            If ``length_estimate`` is not a member of ``_LENGTH_MODES``.
        """
        if length_estimate not in _LENGTH_MODES:
            raise ValueError(
                f"Invalid length_estimate mode: {length_estimate!r}. "
                f"Must be one of {_LENGTH_MODES}."
            )

        # --- Backends ---------------------------------------------------
        # `_backend` is what iteration reads from. `_natural_backend` is
        # the baseline that set_epoch(0) puts back; it starts out as the
        # constructor argument and is replaced by shuffle(), which hands
        # the seed-shuffled backend to the wrapper it returns.
        self._backend = backend
        self._natural_backend = backend

        # --- View window (absolute backend indices) ----------------------
        # slice() and shard() both narrow this window; shard() additionally
        # records (num_shards, shard_index) purely for display purposes.
        self._split_start_idx: Optional[int] = None
        self._split_end_idx: Optional[int] = None
        self._shard: Optional[Tuple[int, int]] = None  # (num, idx) for repr

        # --- Shuffling -----------------------------------------------------
        self._base_shuffle_seed: Optional[int] = None
        self._epoch: int = 0
        self._shuffle_buffer_size: Optional[int] = None

        # --- Map chain -----------------------------------------------------
        # One dict per map() call, applied in insertion order. Keys:
        # "fn", "batched", "batch_size", "drop_last_batch", "with_indices",
        # "input_columns", "remove_columns", "fn_kwargs".
        self._maps: List[Dict[str, Any]] = []

        # --- Length estimation ---------------------------------------------
        self.length_estimate_mode: str = length_estimate
        self._reset_length_on_iter: bool = reset_length_on_iter
        self._input_count: int = 0
        self._output_count: int = 0
        self._cached_exact_length: Optional[int] = None
        self._length_invalidated: bool = False
        self._current_batch_buffer_size: int = 0

        # Raised by load_state_dict so that the following __iter__ keeps
        # the restored cursor and counters rather than starting afresh.
        self._restored_from_checkpoint: bool = False

    # ----- Construction helpers -----

    def _clone(self, **overrides) -> "ComposableIterableDataset":
        """Return a shallow copy with overrides applied to instance attrs."""
        new = ComposableIterableDataset.__new__(ComposableIterableDataset)
        new._backend = overrides.get("backend", self._backend)
        new._natural_backend = overrides.get("natural_backend", self._natural_backend)
        new._split_start_idx = overrides.get("slice_start", self._split_start_idx)
        new._split_end_idx = overrides.get("slice_end", self._split_end_idx)
        new._shard = overrides.get("shard", self._shard)
        new._base_shuffle_seed = overrides.get(
            "base_shuffle_seed", self._base_shuffle_seed
        )
        new._epoch = overrides.get("epoch", self._epoch)
        new._shuffle_buffer_size = overrides.get(
            "shuffle_buffer_size", self._shuffle_buffer_size
        )
        new._maps = overrides.get("maps", list(self._maps))
        new.length_estimate_mode = overrides.get(
            "length_estimate_mode", self.length_estimate_mode
        )
        new._reset_length_on_iter = overrides.get(
            "reset_length_on_iter", self._reset_length_on_iter
        )
        new._input_count = overrides.get("input_count", self._input_count)
        new._output_count = overrides.get("output_count", self._output_count)
        new._cached_exact_length = overrides.get(
            "cached_exact_length", self._cached_exact_length
        )
        new._length_invalidated = overrides.get(
            "length_invalidated", self._length_invalidated
        )
        new._current_batch_buffer_size = overrides.get(
            "current_batch_buffer_size", self._current_batch_buffer_size
        )
        new._restored_from_checkpoint = overrides.get(
            "restored_from_checkpoint", self._restored_from_checkpoint
        )
        return new

    # ----- Backend metadata pass-through -----

    @property
    def backend(self) -> IterableDatasetBackend:
        """The backend instance this wrapper currently reads from."""
        return self._backend

    @property
    def column_names(self) -> Optional[List[str]]:
        """The backend's ``column_names`` attribute, or ``None`` if the
        backend does not define one."""
        return getattr(self._backend, "column_names", None)

    @property
    def features(self):
        """The backend's ``features`` attribute, or ``None`` if the
        backend does not define one."""
        return getattr(self._backend, "features", None)

    @property
    def n_shards(self) -> int:
        n = getattr(self._backend, "n_shards", None)
        return n if n is not None else 1

    @property
    def _shuffle_seed(self) -> Optional[int]:
        """
        Effective shuffle seed currently in use (``base_seed + epoch``
        when both are set, just ``epoch`` if epoch>0 with no base seed,
        else the base seed if any). Exposed for introspection / parity
        with the legacy class — internal logic computes it on the fly
        via `_effective_buffer_seed`.
        """
        if self._base_shuffle_seed is not None:
            return self._base_shuffle_seed + self._epoch
        if self._epoch > 0:
            return self._epoch
        return None

    # ----- Effective view bounds in backend space -----

    def _view_bounds(self) -> Tuple[int, int]:
        """
        (start, end) in backend space after slice + shard. End exclusive.
        """
        backend_len = len(self._backend)
        start = self._split_start_idx if self._split_start_idx is not None else 0
        end = self._split_end_idx if self._split_end_idx is not None else backend_len
        return start, end

    @staticmethod
    def _get_worker_info() -> Tuple[int, int]:
        info = torch.utils.data.get_worker_info()
        if info is None:
            return 0, 1
        return int(info.id), int(info.num_workers)

    def _worker_view_bounds(self) -> Tuple[int, int]:
        """
        (start, end) restricted to the current worker's slice of the
        view. Identical to `_view_bounds()` outside DataLoader workers
        or when num_workers <= 1.
        """
        start, end = self._view_bounds()
        worker_id, num_workers = self._get_worker_info()
        if num_workers <= 1:
            return start, end
        total = end - start
        if total <= 0:
            return start, end
        per_worker = int(math.ceil(total / num_workers))
        worker_start = start + worker_id * per_worker
        worker_end = min(worker_start + per_worker, end)
        if worker_start >= end:
            # This worker has no work — yield an empty range.
            return end, end
        return worker_start, worker_end

    # ----- Length -----

    def __len__(self) -> int:
        start, end = self._view_bounds()
        base = max(0, end - start)
        if self.length_estimate_mode == "static":
            return base
        if self._cached_exact_length is not None:
            return self._cached_exact_length
        if self._output_count > 0 and self._input_count > 0:
            ratio = self._output_count / self._input_count
            return int(base * ratio)
        return base

    def set_length_estimate_mode(self, mode: str) -> None:
        """
        Change the length-estimation mode in place.

        Raises
        ------
        ValueError
            If ``mode`` is not a member of ``_LENGTH_MODES``.
        """
        if mode in _LENGTH_MODES:
            self.length_estimate_mode = mode
            return
        raise ValueError(f"Invalid mode: {mode!r}. Must be one of {_LENGTH_MODES}.")

    def get_length_stats(self) -> Dict[str, Any]:
        start, end = self._view_bounds()
        base = max(0, end - start)
        ratio: Optional[float]
        if self._output_count > 0 and self._input_count > 0:
            ratio = self._output_count / self._input_count
        else:
            ratio = None
        return {
            "mode": self.length_estimate_mode,
            "original_length": base,
            "input_count": self._input_count,
            "output_count": self._output_count,
            "cached_exact": self._cached_exact_length,
            "ratio": ratio,
            "invalidated": self._length_invalidated,
            "batch_buffer_size": self._current_batch_buffer_size,
            "reset_on_iter": self._reset_length_on_iter,
            "current_estimate": len(self),
        }

    # ----- shuffle / set_epoch -----

    def shuffle(
        self,
        seed: Optional[int] = None,
        buffer_size: Optional[int] = 1000,
    ) -> "ComposableIterableDataset":
        """
        Re-permute the underlying example order via the backend and
        configure an example-level reservoir shuffle buffer.

        Length-estimation cache is invalidated; existing input/output
        counts are preserved as a ratio carry-over.
        """
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        if buffer_size is None or buffer_size <= 0:
            buffer_size = None
        new_backend = self._backend.shuffle(seed)
        return self._clone(
            backend=new_backend,
            # The seed-shuffled backend is the new "natural" baseline:
            # subsequent set_epoch(0) should restore THIS state, not
            # the original-order backend the wrapper was constructed
            # with. set_epoch(N>0) re-derives base_seed+N from here.
            natural_backend=new_backend,
            base_shuffle_seed=seed,
            epoch=0,
            shuffle_buffer_size=buffer_size,
            # Invalidate but preserve ratio.
            cached_exact_length=None,
            length_invalidated=True,
            # Counts intentionally carried over.
        )

    def set_epoch(self, epoch: int) -> None:
        """
        Set the current epoch and re-shuffle the backend if any seed
        is in play. Mutates in place.

        ``set_epoch(0)`` always restores the wrapper's natural
        backend state (the post-construction or post-``shuffle()``
        baseline) — even if a previous ``set_epoch(N>0)`` left the
        backend in an N-shuffled state. Without this, going back to
        epoch 0 would silently reuse the stale epoch-N order.
        """
        self._epoch = epoch
        if epoch == 0:
            # Restore the baseline: either the post-construction
            # backend, or the seed-shuffled backend that shuffle()
            # set as the new natural state.
            self._backend = self._natural_backend
            return
        if self._base_shuffle_seed is not None:
            effective = self._base_shuffle_seed + epoch
        else:
            effective = epoch
        self._backend = self._natural_backend.shuffle(effective)

    # ----- slice / select / shard -----

    def slice(
        self,
        start: Optional[int | float | str] = None,
        end: Optional[int | float | str] = None,
    ) -> "ComposableIterableDataset":
        """Return a view restricted to ``[start, end)``."""

        def parse(idx, total):
            if idx is None:
                return None
            if isinstance(idx, str):
                if idx.endswith("%"):
                    idx = float(idx[:-1]) / 100.0
                else:
                    idx = float(idx)
            if isinstance(idx, float):
                if not 0 <= idx <= 1:
                    raise ValueError(f"Percentage must be in range [0, 1], got {idx}")
                return int(idx * total)
            if isinstance(idx, int):
                if idx < 0:
                    return total + idx
                return idx
            raise ValueError(f"Invalid index type: {type(idx)}")

        cur_start, cur_end = self._view_bounds()
        cur_len = cur_end - cur_start

        rel_start = parse(start, cur_len) if start is not None else 0
        rel_end = parse(end, cur_len) if end is not None else cur_len

        if not 0 <= rel_start <= cur_len:
            raise ValueError(f"Start index {rel_start} out of range [0, {cur_len}]")
        if not 0 <= rel_end <= cur_len:
            raise ValueError(f"End index {rel_end} out of range [0, {cur_len}]")
        if rel_start >= rel_end:
            raise ValueError(f"Start index {rel_start} must be < end index {rel_end}")

        return self._clone(
            slice_start=cur_start + rel_start,
            slice_end=cur_start + rel_end,
            # Different view; counts and cache no longer apply.
            input_count=0,
            output_count=0,
            cached_exact_length=None,
            length_invalidated=False,
        )

    def select(self, indices) -> "ComposableIterableDataset":
        """Contiguous-range select; non-contiguous indices not supported."""
        if hasattr(indices, "tolist"):
            indices = indices.tolist()
        elif not isinstance(indices, list):
            indices = list(indices)
        if not indices:
            raise ValueError("Cannot select from empty indices")
        start = indices[0]
        end = indices[-1] + 1
        if indices != list(range(start, end)):
            raise NotImplementedError(
                "Only contiguous, ordered index sequences are supported."
            )
        return self.slice(start, end)

    def shard(self, num_shards: int, index: int) -> "ComposableIterableDataset":
        """
        Split into ``num_shards`` disjoint slices and return the one
        at ``index``. Logical sharding only — there is no ``mode``
        parameter at this layer; the backend may do whatever physical
        optimization it wants internally.
        """
        if num_shards < 1:
            raise ValueError(f"num_shards must be >= 1, got {num_shards}")
        if not 0 <= index < num_shards:
            raise ValueError(f"index ({index}) must be in [0, {num_shards})")

        cur_start, cur_end = self._view_bounds()
        total = cur_end - cur_start
        per_shard = total // num_shards
        remainder = total % num_shards
        # Distribute remainder examples to first `remainder` shards.
        if index < remainder:
            shard_offset = index * (per_shard + 1)
            shard_size = per_shard + 1
        else:
            shard_offset = index * per_shard + remainder
            shard_size = per_shard

        return self._clone(
            slice_start=cur_start + shard_offset,
            slice_end=cur_start + shard_offset + shard_size,
            shard=(num_shards, index),
            input_count=0,
            output_count=0,
            cached_exact_length=None,
            length_invalidated=False,
        )

    # ----- map / filter -----

    def map(
        self,
        function: Optional[Callable] = None,
        with_indices: bool = False,
        input_columns: Optional[str | List[str]] = None,
        batched: bool = False,
        batch_size: Optional[int] = 1000,
        drop_last_batch: bool = False,
        remove_columns: Optional[str | List[str]] = None,
        fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ComposableIterableDataset":
        """
        Append a map step to the chain. Multiple map calls compose.

        A non-batched function returning ``None`` filters the example
        out (matches the legacy Arrow class behavior).

        Mixed batched / non-batched chains are not supported (raises).
        """
        if function is None:
            function = _identity
        if isinstance(input_columns, str):
            input_columns = [input_columns]
        if isinstance(remove_columns, str):
            remove_columns = [remove_columns]
        if fn_kwargs is None:
            fn_kwargs = {}

        if self._maps:
            existing_batched = self._maps[0]["batched"]
            if existing_batched != batched:
                raise ValueError("Cannot chain maps with different batched modes.")

        new_maps = list(self._maps)
        new_maps.append(
            {
                "fn": function,
                "batched": batched,
                "batch_size": batch_size,
                "drop_last_batch": drop_last_batch,
                "with_indices": with_indices,
                "input_columns": input_columns,
                "remove_columns": remove_columns,
                "fn_kwargs": fn_kwargs,
            }
        )
        return self._clone(
            maps=new_maps,
            input_count=0,
            output_count=0,
            cached_exact_length=None,
            length_invalidated=False,
        )

    def filter(
        self,
        function: Callable,
        with_indices: bool = False,
        input_columns: Optional[str | List[str]] = None,
        fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ComposableIterableDataset":
        """Keep examples where ``function(example)`` returns truthy."""
        if isinstance(input_columns, str):
            input_columns = [input_columns]
        if fn_kwargs is None:
            fn_kwargs = {}

        def _filter_map(example, *args, **kwargs):
            if input_columns is not None:
                fn_input = {c: example[c] for c in input_columns if c in example}
            else:
                fn_input = example
            keep = (
                function(fn_input, *args, **fn_kwargs)
                if (with_indices and args)
                else function(fn_input, **fn_kwargs)
            )
            return example if keep else None

        return self.map(_filter_map, with_indices=with_indices)

    # ----- iteration -----

    def __iter__(self) -> Iterator[Dict]:
        """
        Iterate over this worker's window of the view.

        The pipeline is assembled in a fixed order: windowed read from
        the backend, then the shuffle buffer (only when a buffer size
        is set), then either the map chain (batched or per-example,
        decided by the first map step) or a pass-through tracker when
        no maps are configured.

        `_on_iter_done` always runs when the generator finishes or is
        closed; it is told whether the loop reached its natural end.
        """
        # Decide whether to reset count state at the start of this pass.
        self._maybe_reset_counts()

        start, end = self._worker_view_bounds()

        # Position the backend at our window start unless we're resuming
        # mid-window (after load_state_dict or partial iteration that
        # left the cursor inside [start, end)).
        # NOTE(review): after a partial pass that was not restored via
        # load_state_dict, _maybe_reset_counts() above has already zeroed
        # the counters, so a resumed pass that then completes would
        # presumably cache only the remainder's count as the "exact"
        # length — confirm against the counting helpers.
        cur = self._backend.position()
        if cur < start or cur >= end:
            self._backend = self._backend.seek(start)

        # Clear the restored flag — by the time we're iterating we've
        # honored it.
        self._restored_from_checkpoint = False

        # Index passed to map(with_indices=True): the actual position
        # of the first example we're about to yield, NOT the static
        # window start. Critical when resuming mid-window — otherwise
        # `with_indices` indices restart at `start` regardless of
        # where the saved cursor actually is.
        first_idx = self._backend.position()

        gen = self._iter_window(self._backend, start, end)
        if self._shuffle_buffer_size:
            gen = self._reservoir_buffer(
                gen, self._shuffle_buffer_size, self._effective_buffer_seed()
            )
        if self._maps:
            if self._maps[0]["batched"]:
                gen = self._apply_batched_maps(gen, first_idx)
            else:
                gen = self._apply_single_maps(gen, first_idx)
        else:
            gen = self._track_passthrough(gen)

        # `completed` stays False if the consumer stops early or an
        # exception escapes, so only a full pass can lock in a length.
        completed = False
        try:
            for ex in gen:
                yield ex
            completed = True
        finally:
            self._on_iter_done(completed)

    def _maybe_reset_counts(self) -> None:
        """
        Prepare the input/output counters for a new pass.

        Nothing is touched right after a checkpoint restore. Otherwise
        the counters are zeroed, except when the length was invalidated
        while both counters were positive (they are kept as a ratio).
        The cached exact length is dropped only when a reset was
        requested or the length was invalidated.
        """
        # Saved counts must survive into the first pass after a restore.
        if self._restored_from_checkpoint:
            return

        invalidating = self._reset_length_on_iter or self._length_invalidated
        carry_ratio = (
            self._length_invalidated
            and self._input_count > 0
            and self._output_count > 0
        )

        if not carry_ratio:
            self._input_count = 0
            self._output_count = 0
        if invalidating:
            # Without invalidation the cached exact length persists
            # across passes once it has been observed.
            self._cached_exact_length = None
            self._length_invalidated = False
        self._current_batch_buffer_size = 0

    def _on_iter_done(self, completed: bool) -> None:
        """
        Finish a pass: clear the in-flight batch size and, for a pass
        that ran to its natural end in ``"dynamic"`` or ``"exact"``
        mode, cache a positive output count as the exact length.
        """
        self._current_batch_buffer_size = 0
        # An interrupted pass says nothing about the full length, and a
        # zero count is never cached.
        if (
            completed
            and self.length_estimate_mode in ("dynamic", "exact")
            and self._output_count > 0
        ):
            self._cached_exact_length = self._output_count

    def _iter_window(
        self,
        backend: IterableDatasetBackend,
        start: int,
        end: int,
    ) -> Iterator[Dict]:
        """
        Yield the backend's examples whose flat index lies in
        ``[start, end)``.

        The backend position is checked *before* each fetch so nothing
        past ``end`` is ever pulled; examples fetched before ``start``
        are discarded.
        """
        if start >= end:
            return
        source = iter(backend)
        while backend.position() < end:
            try:
                item = next(source)
            except StopIteration:
                break
            # After the fetch the cursor is one past the example we
            # just received.
            if backend.position() - 1 >= start:
                yield item

    def _effective_buffer_seed(self) -> int:
        """
        Return the seed for the shuffle buffer: base seed plus epoch
        when a base seed is set, otherwise the epoch (``0`` if falsy).
        """
        base = self._base_shuffle_seed
        if base is None:
            return self._epoch or 0
        return base + self._epoch

    @staticmethod
    def _reservoir_buffer(
        it: Iterator[Dict], buffer_size: int, seed: int
    ) -> Iterator[Dict]:
        rng = random.Random(seed)
        buf: List[Dict] = []
        for ex in it:
            buf.append(ex)
            if len(buf) >= buffer_size:
                break
        if not buf:
            return
        for ex in it:
            idx = rng.randint(0, buffer_size - 1)
            yield buf[idx]
            buf[idx] = ex
        rng.shuffle(buf)
        yield from buf

    def _track_passthrough(self, it: Iterator[Dict]) -> Iterator[Dict]:
        """
        Yield examples unchanged while counting each one as both an
        input and an output (used when no map is configured).
        """
        for example in it:
            # Counters are bumped before the yield so they already
            # include the example the consumer is looking at.
            self._output_count += 1
            self._input_count += 1
            yield example

    def _apply_single_maps(self, it: Iterator[Dict], start_idx: int) -> Iterator[Dict]:
        """
        Run every example through the whole non-batched map chain.

        An example is dropped as soon as one step returns ``None``.
        Each input example consumes one index (counted from
        ``start_idx``) whether or not it survives the chain.
        """
        for idx, example in enumerate(it, start_idx):
            self._input_count += 1
            current = example
            for spec in self._maps:
                current = self._call_single(spec, current, idx)
                if current is None:
                    # Filtered out: skip the remaining steps.
                    break
            else:
                # No step rejected the example.
                self._output_count += 1
                yield current

    @staticmethod
    def _call_single(spec: Dict[str, Any], example: Dict, idx: int) -> Optional[Dict]:
        if spec["input_columns"] is not None:
            fn_input = {c: example[c] for c in spec["input_columns"] if c in example}
        else:
            fn_input = example
        if spec["with_indices"]:
            result = spec["fn"](fn_input, idx, **spec["fn_kwargs"])
        else:
            result = spec["fn"](fn_input, **spec["fn_kwargs"])
        if result is None:
            return None
        if not isinstance(result, dict):
            raise ValueError(
                f"Map function must return a dict or None, got {type(result)}"
            )
        merged = example.copy()
        merged.update(result)
        if spec["remove_columns"] is not None:
            for col in spec["remove_columns"]:
                merged.pop(col, None)
        return merged

    def _apply_batched_maps(self, it: Iterator[Dict], start_idx: int) -> Iterator[Dict]:
        """
        Group the stream into batches and run the batched map chain on
        each one.

        Batch size and the drop-last flag come from the first map in
        the chain (size falls back to 1000 when falsy). A trailing
        partial batch is processed unless drop-last is set.
        """
        head = self._maps[0]
        chunk_size = head["batch_size"] or 1000
        discard_partial = head["drop_last_batch"]

        def process(chunk, offset):
            # Run the chain, then account for the batch before any of
            # its results are handed to the consumer.
            produced = self._run_batched_chain(chunk, offset)
            self._input_count += len(chunk)
            self._output_count += len(produced)
            self._current_batch_buffer_size = 0
            return produced

        pending: List[Dict] = []
        offset = start_idx
        for example in it:
            pending.append(example)
            self._current_batch_buffer_size = len(pending)
            if len(pending) < chunk_size:
                continue
            yield from process(pending, offset)
            offset += len(pending)
            pending = []
        if pending and not discard_partial:
            yield from process(pending, offset)

    def _run_batched_chain(self, batch: List[Dict], batch_start: int) -> List[Dict]:
        """
        Feed ``batch`` through every batched map step in order and
        return the final examples; stop with ``[]`` as soon as a step
        leaves nothing.
        """
        rows = batch
        for step in self._maps:
            rows = self._apply_batched_step(step, rows, batch_start)
            if rows:
                continue
            # Nothing left for the remaining steps to work on.
            return []
        return rows

    @staticmethod
    def _apply_batched_step(
        spec: Dict[str, Any], examples: List[Dict], batch_start: int
    ) -> List[Dict]:
        if not examples:
            return examples
        # Collect into dict-of-lists.
        if spec["input_columns"] is not None:
            cols = spec["input_columns"]
        else:
            cols = sorted({k for ex in examples for k in ex.keys()})
        batch_dict = {c: [ex.get(c) for ex in examples] for c in cols}

        if spec["with_indices"]:
            indices = list(range(batch_start, batch_start + len(examples)))
            result = spec["fn"](batch_dict, indices, **spec["fn_kwargs"])
        else:
            result = spec["fn"](batch_dict, **spec["fn_kwargs"])

        if result is None:
            return []
        if not isinstance(result, dict):
            raise ValueError(
                f"Batched map function must return a dict or None, got {type(result)}"
            )
        # Determine output count from the first list-valued column.
        n_out = 0
        for v in result.values():
            if isinstance(v, list):
                n_out = len(v)
                break

        out: List[Dict] = []
        for i in range(n_out):
            row = {}
            for c, v in result.items():
                row[c] = v[i] if isinstance(v, list) else v
            if i < len(examples):
                merged = examples[i].copy()
                merged.update(row)
            else:
                merged = row
            if spec["remove_columns"] is not None:
                for c in spec["remove_columns"]:
                    merged.pop(c, None)
            out.append(merged)
        return out

    # ----- HF compatibility -----

    def to_hf_iterable(self):
        """
        Return a HuggingFace ``IterableDataset`` view of this dataset
        for APIs that insist on one.

        The view is wrapped in `IterableDatasetWithLength` together
        with ``len(self)`` so that it reports a length. The wrapper's
        checkpoint protocol is *not* carried over to the returned
        object.
        """
        from datasets import IterableDataset as HFIterableDataset

        from .iterable_with_length import IterableDatasetWithLength

        def gen():
            # Delegates to this wrapper's own __iter__.
            yield from self

        hf_view = HFIterableDataset.from_generator(gen)
        length = len(self)
        return IterableDatasetWithLength(hf_view, length)

    # ----- checkpoint protocol -----

    def state_dict(self) -> Dict[str, Any]:
        """
        Snapshot the wrapper configuration, its counters, and the
        backend cursor.

        The stored ``position`` is the backend's flat index in
        underlying-example space (not post-slice/shard/map space), so a
        resume can ``seek`` to it no matter how maps changed
        cardinality. If the backend has a callable ``state_dict`` its
        result is stored too; a failure there is logged at debug level
        and only the position is kept.
        """
        backend = self._backend
        backend_state: Dict[str, Any] = {"position": backend.position()}
        exporter = getattr(backend, "state_dict", None)
        if callable(exporter):
            try:
                backend_state["state_dict"] = exporter()
            except Exception as exc:  # pragma: no cover
                # Best effort: the flat position above is still usable.
                logger.debug(
                    "backend.state_dict() failed; falling back to position only: %s",
                    exc,
                )

        maps = self._maps
        snapshot: Dict[str, Any] = {
            "wrapper_version": 1,
            "slice_start": self._split_start_idx,
            "slice_end": self._split_end_idx,
            "shard": self._shard,
            "base_shuffle_seed": self._base_shuffle_seed,
            "epoch": self._epoch,
            "shuffle_buffer_size": self._shuffle_buffer_size,
            # Map functions are not stored, only the chain's shape.
            "n_maps": len(maps),
            "maps_batched": maps[0]["batched"] if maps else None,
            "length_estimate_mode": self.length_estimate_mode,
            "reset_length_on_iter": self._reset_length_on_iter,
            "input_count": self._input_count,
            "output_count": self._output_count,
            "cached_exact_length": self._cached_exact_length,
            "length_invalidated": self._length_invalidated,
        }
        snapshot["backend"] = backend_state
        return snapshot

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """
        Restore wrapper state from ``state`` and reposition the backend.

        Map functions are not serialised: the caller must rebuild the
        same map chain first. Only the chain's length and batched mode
        are verified here; function identities are not compared.

        Raises:
            ValueError: On an unknown ``wrapper_version`` or when the
                saved map-chain length / batched mode differs from the
                current wrapper's.
        """
        version = state.get("wrapper_version")
        if version != 1:
            raise ValueError(f"Unknown wrapper_version: {version}")

        # Cheap shape guard only. Callable identity is not stable
        # across pickle / fork / process boundaries (e.g.
        # ``functools.partial.__repr__`` embeds a memory address), and
        # matching on it produced false-positive failures in
        # multi-worker DataLoader resume.
        expected_maps = state.get("n_maps")
        actual_maps = len(self._maps)
        if expected_maps is not None and expected_maps != actual_maps:
            raise ValueError(
                f"Map chain length mismatch on load_state_dict: saved "
                f"{expected_maps} maps, current wrapper has {actual_maps}. "
                "Reconstruct the wrapper with the same map chain before "
                "calling load_state_dict."
            )
        expected_batched = state.get("maps_batched")
        actual_batched = self._maps[0]["batched"] if self._maps else None
        if expected_batched is not None and expected_batched != actual_batched:
            raise ValueError(
                f"Map chain batched-mode mismatch on load_state_dict: "
                f"saved batched={expected_batched}, current batched="
                f"{actual_batched}."
            )

        # View / shuffle configuration: absent keys restore as None.
        for attr, key in (
            ("_split_start_idx", "slice_start"),
            ("_split_end_idx", "slice_end"),
            ("_shard", "shard"),
            ("_base_shuffle_seed", "base_shuffle_seed"),
        ):
            setattr(self, attr, state.get(key))
        self._epoch = state.get("epoch", 0)
        self._shuffle_buffer_size = state.get("shuffle_buffer_size")

        # Length-estimation settings keep their current value when the
        # checkpoint does not carry them.
        self.length_estimate_mode = state.get(
            "length_estimate_mode", self.length_estimate_mode
        )
        self._reset_length_on_iter = state.get(
            "reset_length_on_iter", self._reset_length_on_iter
        )
        self._input_count = int(state.get("input_count", 0))
        self._output_count = int(state.get("output_count", 0))
        self._cached_exact_length = state.get("cached_exact_length")
        self._length_invalidated = bool(state.get("length_invalidated", False))

        # Prefer the backend's own restore hook when both sides have
        # one; otherwise fall back to seeking to the flat position.
        backend_state = state.get("backend", {})
        if "state_dict" in backend_state and callable(
            getattr(self._backend, "load_state_dict", None)
        ):
            self._backend.load_state_dict(backend_state["state_dict"])
        else:
            self._backend = self._backend.seek(backend_state["position"])

        # Flag the restore so the next iteration keeps the saved counts.
        self._restored_from_checkpoint = True

    # ----- repr -----

    def __repr__(self) -> str:
        """
        Return a one-line summary: backend type, length, slice bounds,
        shard, number of maps, shuffle buffer size, epoch and
        length-estimate mode.
        """
        fields = (
            f"backend={type(self._backend).__name__}",
            f"len={len(self)}",
            f"slice=({self._split_start_idx},{self._split_end_idx})",
            f"shard={self._shard}",
            f"n_maps={len(self._maps)}",
            f"buffer={self._shuffle_buffer_size}",
            f"epoch={self._epoch}",
            f"mode={self.length_estimate_mode}",
        )
        return "ComposableIterableDataset(" + ", ".join(fields) + ")"

shuffle(seed=None, buffer_size=1000)

Re-permute the underlying example order via the backend and configure an example-level reservoir shuffle buffer.

Length-estimation cache is invalidated; existing input/output counts are preserved as a ratio carry-over.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def shuffle(
    self,
    seed: Optional[int] = None,
    buffer_size: Optional[int] = 1000,
) -> "ComposableIterableDataset":
    """
    Return a clone whose backend order is re-permuted with ``seed``
    and whose example-level shuffle buffer holds ``buffer_size``
    examples.

    A missing seed is drawn at random; a ``None`` or non-positive
    ``buffer_size`` disables the buffer. The length-estimation cache
    is invalidated while the existing input/output counts are carried
    over as a ratio.
    """
    chosen_seed = random.randint(0, 2**31 - 1) if seed is None else seed
    if buffer_size is not None and buffer_size > 0:
        effective_buffer = buffer_size
    else:
        effective_buffer = None

    permuted = self._backend.shuffle(chosen_seed)
    clone_kwargs = dict(
        backend=permuted,
        # The seed-shuffled backend becomes the new "natural"
        # baseline: a later set_epoch(0) restores THIS order, and
        # set_epoch(N>0) re-derives base_seed+N from it.
        natural_backend=permuted,
        base_shuffle_seed=chosen_seed,
        epoch=0,
        shuffle_buffer_size=effective_buffer,
        # Invalidate the cache; counts are deliberately not reset.
        cached_exact_length=None,
        length_invalidated=True,
    )
    return self._clone(**clone_kwargs)

set_epoch(epoch)

Set the current epoch and re-shuffle the backend if any seed is in play. Mutates in place.

set_epoch(0) always restores the wrapper's natural backend state (the post-construction or post-shuffle() baseline) — even if a previous set_epoch(N>0) left the backend in an N-shuffled state. Without this, going back to epoch 0 would silently reuse the stale epoch-N order.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def set_epoch(self, epoch: int) -> None:
    """
    Record ``epoch`` and swap in the matching backend order. Mutates
    this wrapper in place.

    Epoch ``0`` always puts back the natural backend (the
    post-construction or post-``shuffle()`` baseline), so returning to
    epoch 0 never reuses a stale epoch-N order. Any other epoch
    shuffles the natural backend with ``base_shuffle_seed + epoch``,
    or with ``epoch`` alone when no base seed is set.
    """
    self._epoch = epoch
    baseline = self._natural_backend
    if epoch == 0:
        self._backend = baseline
        return
    base = self._base_shuffle_seed
    derived_seed = epoch if base is None else base + epoch
    self._backend = baseline.shuffle(derived_seed)

slice(start=None, end=None)

Return a view restricted to [start, end).

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def slice(
    self,
    start: Optional[int | float | str] = None,
    end: Optional[int | float | str] = None,
) -> "ComposableIterableDataset":
    """Return a view restricted to ``[start, end)``."""

    def parse(idx, total):
        if idx is None:
            return None
        if isinstance(idx, str):
            if idx.endswith("%"):
                idx = float(idx[:-1]) / 100.0
            else:
                idx = float(idx)
        if isinstance(idx, float):
            if not 0 <= idx <= 1:
                raise ValueError(f"Percentage must be in range [0, 1], got {idx}")
            return int(idx * total)
        if isinstance(idx, int):
            if idx < 0:
                return total + idx
            return idx
        raise ValueError(f"Invalid index type: {type(idx)}")

    cur_start, cur_end = self._view_bounds()
    cur_len = cur_end - cur_start

    rel_start = parse(start, cur_len) if start is not None else 0
    rel_end = parse(end, cur_len) if end is not None else cur_len

    if not 0 <= rel_start <= cur_len:
        raise ValueError(f"Start index {rel_start} out of range [0, {cur_len}]")
    if not 0 <= rel_end <= cur_len:
        raise ValueError(f"End index {rel_end} out of range [0, {cur_len}]")
    if rel_start >= rel_end:
        raise ValueError(f"Start index {rel_start} must be < end index {rel_end}")

    return self._clone(
        slice_start=cur_start + rel_start,
        slice_end=cur_start + rel_end,
        # Different view; counts and cache no longer apply.
        input_count=0,
        output_count=0,
        cached_exact_length=None,
        length_invalidated=False,
    )

select(indices)

Contiguous-range select; non-contiguous indices not supported.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def select(self, indices) -> "ComposableIterableDataset":
    """
    Select a contiguous, ascending run of indices by delegating to
    ``slice``.

    ``indices`` may be a list, any iterable, or an object exposing
    ``tolist()``.

    Raises:
        ValueError: If ``indices`` is empty.
        NotImplementedError: If the indices are not exactly
            ``first, first + 1, ..., last``.
    """
    if hasattr(indices, "tolist"):
        wanted = indices.tolist()
    elif isinstance(indices, list):
        wanted = indices
    else:
        wanted = list(indices)

    if not wanted:
        raise ValueError("Cannot select from empty indices")

    first, stop = wanted[0], wanted[-1] + 1
    # Anything other than the exact consecutive run is unsupported.
    if wanted != list(range(first, stop)):
        raise NotImplementedError(
            "Only contiguous, ordered index sequences are supported."
        )
    return self.slice(first, stop)

shard(num_shards, index)

Split into num_shards disjoint slices and return the one at index. Logical sharding only — there is no mode parameter at this layer; the backend may do whatever physical optimization it wants internally.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def shard(self, num_shards: int, index: int) -> "ComposableIterableDataset":
    """
    Cut the current view into ``num_shards`` disjoint contiguous
    pieces and return piece number ``index``.

    This is purely logical sharding (there is no ``mode`` parameter
    at this layer). When the length does not divide evenly, the first
    ``length % num_shards`` shards get one extra example each.

    Raises:
        ValueError: If ``num_shards < 1`` or ``index`` is outside
            ``[0, num_shards)``.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")
    if not 0 <= index < num_shards:
        raise ValueError(f"index ({index}) must be in [0, {num_shards})")

    view_lo, view_hi = self._view_bounds()
    base_size, extras = divmod(view_hi - view_lo, num_shards)

    # Every shard before this one holds `base_size` examples, plus one
    # more for each of them that is among the first `extras` shards.
    offset = index * base_size + min(index, extras)
    size = base_size + (1 if index < extras else 0)

    begin = view_lo + offset
    return self._clone(
        slice_start=begin,
        slice_end=begin + size,
        shard=(num_shards, index),
        input_count=0,
        output_count=0,
        cached_exact_length=None,
        length_invalidated=False,
    )

map(function=None, with_indices=False, input_columns=None, batched=False, batch_size=1000, drop_last_batch=False, remove_columns=None, fn_kwargs=None)

Append a map step to the chain. Multiple map calls compose.

A non-batched function returning None filters the example out (matches the legacy Arrow class behavior).

Mixed batched / non-batched chains are not supported (raises).

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def map(
    self,
    function: Optional[Callable] = None,
    with_indices: bool = False,
    input_columns: Optional[str | List[str]] = None,
    batched: bool = False,
    batch_size: Optional[int] = 1000,
    drop_last_batch: bool = False,
    remove_columns: Optional[str | List[str]] = None,
    fn_kwargs: Optional[Dict[str, Any]] = None,
) -> "ComposableIterableDataset":
    """
    Append a map step to the chain. Multiple map calls compose.

    A non-batched function returning ``None`` filters the example
    out (matches the legacy Arrow class behavior).

    Mixed batched / non-batched chains are not supported (raises).
    """
    if function is None:
        function = _identity
    if isinstance(input_columns, str):
        input_columns = [input_columns]
    if isinstance(remove_columns, str):
        remove_columns = [remove_columns]
    if fn_kwargs is None:
        fn_kwargs = {}

    if self._maps:
        existing_batched = self._maps[0]["batched"]
        if existing_batched != batched:
            raise ValueError("Cannot chain maps with different batched modes.")

    new_maps = list(self._maps)
    new_maps.append(
        {
            "fn": function,
            "batched": batched,
            "batch_size": batch_size,
            "drop_last_batch": drop_last_batch,
            "with_indices": with_indices,
            "input_columns": input_columns,
            "remove_columns": remove_columns,
            "fn_kwargs": fn_kwargs,
        }
    )
    return self._clone(
        maps=new_maps,
        input_count=0,
        output_count=0,
        cached_exact_length=None,
        length_invalidated=False,
    )

filter(function, with_indices=False, input_columns=None, fn_kwargs=None)

Keep examples where function(example) returns truthy.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def filter(
    self,
    function: Callable,
    with_indices: bool = False,
    input_columns: Optional[str | List[str]] = None,
    fn_kwargs: Optional[Dict[str, Any]] = None,
) -> "ComposableIterableDataset":
    """Keep examples where ``function(example)`` returns truthy."""
    if isinstance(input_columns, str):
        input_columns = [input_columns]
    if fn_kwargs is None:
        fn_kwargs = {}

    def _filter_map(example, *args, **kwargs):
        if input_columns is not None:
            fn_input = {c: example[c] for c in input_columns if c in example}
        else:
            fn_input = example
        keep = (
            function(fn_input, *args, **fn_kwargs)
            if (with_indices and args)
            else function(fn_input, **fn_kwargs)
        )
        return example if keep else None

    return self.map(_filter_map, with_indices=with_indices)

to_hf_iterable()

Wrap this dataset in a HuggingFace IterableDataset for APIs that require one. The returned object exposes __len__ via IterableDatasetWithLength so it can drive torch.DataLoader; the wrapper checkpoint protocol is not preserved on the returned value.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def to_hf_iterable(self):
    """
    Return a HuggingFace ``IterableDataset`` view of this dataset for
    APIs that insist on one.

    The view is wrapped in `IterableDatasetWithLength` together with
    ``len(self)`` so that it reports a length. The wrapper's
    checkpoint protocol is *not* carried over to the returned object.
    """
    from datasets import IterableDataset as HFIterableDataset

    from .iterable_with_length import IterableDatasetWithLength

    def gen():
        # Delegates to this wrapper's own __iter__.
        yield from self

    hf_view = HFIterableDataset.from_generator(gen)
    length = len(self)
    return IterableDatasetWithLength(hf_view, length)

state_dict()

Capture wrapper state plus the backend's flat position.

The backend's position() is in underlying-example space, not in user-facing post-slice/shard/map space — that's deliberate so resume can call backend.seek(saved_position) and continue consuming examples regardless of how a map function may have changed cardinality.

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def state_dict(self) -> Dict[str, Any]:
    """
    Snapshot the wrapper configuration, its counters, and the backend
    cursor.

    The stored ``position`` is the backend's flat index in
    underlying-example space (not post-slice/shard/map space), so a
    resume can ``seek`` to it no matter how maps changed cardinality.
    If the backend has a callable ``state_dict`` its result is stored
    too; a failure there is logged at debug level and only the
    position is kept.
    """
    backend = self._backend
    backend_state: Dict[str, Any] = {"position": backend.position()}
    exporter = getattr(backend, "state_dict", None)
    if callable(exporter):
        try:
            backend_state["state_dict"] = exporter()
        except Exception as exc:  # pragma: no cover
            # Best effort: the flat position above is still usable.
            logger.debug(
                "backend.state_dict() failed; falling back to position only: %s",
                exc,
            )

    maps = self._maps
    snapshot: Dict[str, Any] = {
        "wrapper_version": 1,
        "slice_start": self._split_start_idx,
        "slice_end": self._split_end_idx,
        "shard": self._shard,
        "base_shuffle_seed": self._base_shuffle_seed,
        "epoch": self._epoch,
        "shuffle_buffer_size": self._shuffle_buffer_size,
        # Map functions are not stored, only the chain's shape.
        "n_maps": len(maps),
        "maps_batched": maps[0]["batched"] if maps else None,
        "length_estimate_mode": self.length_estimate_mode,
        "reset_length_on_iter": self._reset_length_on_iter,
        "input_count": self._input_count,
        "output_count": self._output_count,
        "cached_exact_length": self._cached_exact_length,
        "length_invalidated": self._length_invalidated,
    }
    snapshot["backend"] = backend_state
    return snapshot

load_state_dict(state)

Restore wrapper state and seek the backend to the saved position. Map functions themselves are not serialised — the caller must reconstruct the same map chain before calling load_state_dict (only the chain's length and batched mode are checked; function identities and fingerprints are deliberately not compared).

Source code in src/forgather/ml/datasets/composable_iterable_dataset.py
def load_state_dict(self, state: Dict[str, Any]) -> None:
    """
    Restore wrapper state from ``state`` and reposition the backend.

    Map functions are not serialised: the caller must rebuild the
    same map chain first. Only the chain's length and batched mode
    are verified here; function identities are not compared.

    Raises:
        ValueError: On an unknown ``wrapper_version`` or when the
            saved map-chain length / batched mode differs from the
            current wrapper's.
    """
    version = state.get("wrapper_version")
    if version != 1:
        raise ValueError(f"Unknown wrapper_version: {version}")

    # Cheap shape guard only. Callable identity is not stable across
    # pickle / fork / process boundaries (e.g.
    # ``functools.partial.__repr__`` embeds a memory address), and
    # matching on it produced false-positive failures in multi-worker
    # DataLoader resume.
    expected_maps = state.get("n_maps")
    actual_maps = len(self._maps)
    if expected_maps is not None and expected_maps != actual_maps:
        raise ValueError(
            f"Map chain length mismatch on load_state_dict: saved "
            f"{expected_maps} maps, current wrapper has {actual_maps}. "
            "Reconstruct the wrapper with the same map chain before "
            "calling load_state_dict."
        )
    expected_batched = state.get("maps_batched")
    actual_batched = self._maps[0]["batched"] if self._maps else None
    if expected_batched is not None and expected_batched != actual_batched:
        raise ValueError(
            f"Map chain batched-mode mismatch on load_state_dict: "
            f"saved batched={expected_batched}, current batched="
            f"{actual_batched}."
        )

    # View / shuffle configuration: absent keys restore as None.
    for attr, key in (
        ("_split_start_idx", "slice_start"),
        ("_split_end_idx", "slice_end"),
        ("_shard", "shard"),
        ("_base_shuffle_seed", "base_shuffle_seed"),
    ):
        setattr(self, attr, state.get(key))
    self._epoch = state.get("epoch", 0)
    self._shuffle_buffer_size = state.get("shuffle_buffer_size")

    # Length-estimation settings keep their current value when the
    # checkpoint does not carry them.
    self.length_estimate_mode = state.get(
        "length_estimate_mode", self.length_estimate_mode
    )
    self._reset_length_on_iter = state.get(
        "reset_length_on_iter", self._reset_length_on_iter
    )
    self._input_count = int(state.get("input_count", 0))
    self._output_count = int(state.get("output_count", 0))
    self._cached_exact_length = state.get("cached_exact_length")
    self._length_invalidated = bool(state.get("length_invalidated", False))

    # Prefer the backend's own restore hook when both sides have one;
    # otherwise fall back to seeking to the flat position.
    backend_state = state.get("backend", {})
    if "state_dict" in backend_state and callable(
        getattr(self._backend, "load_state_dict", None)
    ):
        self._backend.load_state_dict(backend_state["state_dict"])
    else:
        self._backend = self._backend.seek(backend_state["position"])

    # Flag the restore so the next iteration keeps the saved counts.
    self._restored_from_checkpoint = True

forgather.ml.datasets.iterable_backend.IterableDatasetBackend

Bases: ABC

Abstract storage backend for an iterable dataset.

The contract: __iter__ yields dict examples in some order, __len__ returns the total example count, position() reports the flat example index where the next iteration would start, and shuffle/seek return a new backend instance with the requested state change.

Implementations must:

  • Be deterministic given the same shuffle seed and seek position.
  • Update position() as iteration progresses, so a wrapper can capture it for state_dict at any point.
  • After shuffle(seed), position resets to 0 (the new instance is fresh). After seek(n), position is n.

Implementations may optionally expose:

  • column_names: list[str] — schema column names.
  • features — schema feature dict (HuggingFace-style).
  • n_shards: int — number of natural physical shards (e.g. files).

The wrapper forwards these via attribute access and tolerates AttributeError for backends that don't provide them.

Source code in src/forgather/ml/datasets/iterable_backend.py
class IterableDatasetBackend(ABC):
    """
    Abstract storage backend for an iterable dataset.

    The contract: `__iter__` yields `dict` examples in some order,
    `__len__` returns the total example count, `position()` reports
    the flat example index where the next iteration would start, and
    `shuffle`/`seek` return a new backend instance with the requested
    state change.

    Implementations must:

    - Be deterministic given the same shuffle seed and seek position.
    - Update `position()` as iteration progresses, so a wrapper can
      capture it for `state_dict` at any point.
    - After `shuffle(seed)`, position resets to 0 (the new instance is
      fresh). After `seek(n)`, position is `n`.

    Implementations may optionally expose:

    - `column_names: list[str]` — schema column names.
    - `features` — schema feature dict (HuggingFace-style).
    - `n_shards: int` — number of natural physical shards (e.g. files).

    The wrapper forwards these via attribute access and tolerates
    `AttributeError` for backends that don't provide them.
    """

    @abstractmethod
    def __iter__(self) -> Iterator[dict]:
        """
        Yield `dict` examples starting at `position()`.

        Returns
        -------
        Iterator[dict]
            Examples in the backend's current order, beginning at the
            current cursor.
        """

    @abstractmethod
    def __len__(self) -> int:
        """
        Total number of examples in the underlying dataset.

        Returns
        -------
        int
            The full example count.
        """

    @abstractmethod
    def shuffle(self, seed: Optional[int] = None) -> "IterableDatasetBackend":
        """
        Return a new backend with the underlying example order
        re-permuted.

        No buffer parameter — the example-level reservoir buffer lives
        in the composing wrapper, not in the backend.

        Parameters
        ----------
        seed : int, optional
            Determines the new order deterministically. If ``None`` an
            implementation-chosen seed is generated and surfaced via
            the new instance's state so it can be reproduced from a
            checkpoint.

        Returns
        -------
        IterableDatasetBackend
            A new instance with `position()` reset to 0.
        """

    @abstractmethod
    def seek(self, position: int) -> "IterableDatasetBackend":
        """
        Return a new backend whose next `__iter__` begins at the given
        flat example index.

        Not expected to be O(1) — implementations may need to walk
        index metadata to translate the flat position into their
        internal representation.

        Parameters
        ----------
        position : int
            Flat example index at which the next iteration starts.

        Returns
        -------
        IterableDatasetBackend
            A new instance with `position()` equal to the requested
            value.
        """

    @abstractmethod
    def position(self) -> int:
        """
        Current flat example index where the next `__iter__` would
        start.

        Must update during iteration so a wrapper can capture it for
        `state_dict()` after any number of yielded examples.

        Returns
        -------
        int
            Index of the next example to be yielded.
        """

__iter__() abstractmethod

Yield dict examples starting at position().

Source code in src/forgather/ml/datasets/iterable_backend.py
@abstractmethod
def __iter__(self) -> Iterator[dict]:
    """
    Yield `dict` examples starting at `position()`.

    Abstract: every concrete backend must supply the iteration logic.
    """

__len__() abstractmethod

Total number of examples in the underlying dataset.

Source code in src/forgather/ml/datasets/iterable_backend.py
@abstractmethod
def __len__(self) -> int:
    """
    Total number of examples in the underlying dataset.

    Abstract: every concrete backend must supply the count.
    """

shuffle(seed=None) abstractmethod

Return a new backend with the underlying example order re-permuted.

No buffer parameter — the example-level reservoir buffer lives in the composing wrapper, not in the backend. The seed determines the new order deterministically; if None an implementation-chosen seed is generated and surfaced via the new instance's state so it can be reproduced from a checkpoint.

The returned instance has position() reset to 0.

Source code in src/forgather/ml/datasets/iterable_backend.py
@abstractmethod
def shuffle(self, seed: Optional[int] = None) -> "IterableDatasetBackend":
    """
    Return a new backend with the underlying example order
    re-permuted.

    No buffer parameter — the example-level reservoir buffer lives
    in the composing wrapper, not in the backend.

    Parameters
    ----------
    seed : int, optional
        Determines the new order deterministically. If ``None`` an
        implementation-chosen seed is generated and surfaced via the
        new instance's state so it can be reproduced from a
        checkpoint.

    Returns
    -------
    IterableDatasetBackend
        A new instance with `position()` reset to 0.
    """

seek(position) abstractmethod

Return a new backend whose next __iter__ begins at the given flat example index.

Not expected to be O(1) — implementations may need to walk index metadata to translate the flat position into their internal representation. The returned instance has position() equal to the requested value.

Source code in src/forgather/ml/datasets/iterable_backend.py
@abstractmethod
def seek(self, position: int) -> "IterableDatasetBackend":
    """
    Return a new backend whose next `__iter__` begins at the given
    flat example index.

    Not expected to be O(1) — implementations may need to walk
    index metadata to translate the flat position into their
    internal representation.

    Parameters
    ----------
    position : int
        Flat example index at which the next iteration starts.

    Returns
    -------
    IterableDatasetBackend
        A new instance with `position()` equal to the requested value.
    """

position() abstractmethod

Current flat example index where the next __iter__ would start.

Must update during iteration so a wrapper can capture it for state_dict() after any number of yielded examples.

Source code in src/forgather/ml/datasets/iterable_backend.py
@abstractmethod
def position(self) -> int:
    """
    Current flat example index where the next `__iter__` would
    start.

    Must update during iteration so a wrapper can capture it for
    `state_dict()` after any number of yielded examples.

    Returns
    -------
    int
        Index of the next example to be yielded.
    """

forgather.ml.datasets.arrow_backend.ArrowBackend

Bases: IterableDatasetBackend

Storage backend over a list of memory-mapped Arrow files.

Parameters:

Name Type Description Default
arrow_files list of str

Ordered list of Arrow file paths that make up the dataset. Each file is treated as one natural shard.

required
file_lengths list of int

Per-file example counts, parallel to arrow_files. When provided, __len__ and seek are O(num_files) without any file I/O. When None, file lengths are read on construction by opening each file (slow path; the loader normally avoids this by passing cached lengths from the on-disk index).

None
Notes

__iter__ mutates the cursor; multiple concurrent iterators on the same instance would interfere. In multi-worker DataLoader setups each worker receives its own copy (via fork or pickle), so concurrent cursors aren't an issue in practice.

Source code in src/forgather/ml/datasets/arrow_backend.py
class ArrowBackend(IterableDatasetBackend):
    """
    Storage backend over a list of memory-mapped Arrow files.

    Parameters
    ----------
    arrow_files : list of str
        Ordered list of Arrow file paths that make up the dataset.
        Each file is treated as one natural shard.
    file_lengths : list of int, optional
        Per-file example counts, parallel to ``arrow_files``. When
        provided, ``__len__`` and ``seek`` are O(num_files) without
        any file I/O. When ``None``, file lengths are read on
        construction by opening each file (slow path; the loader
        normally avoids this by passing cached lengths from the
        on-disk index).

    Raises
    ------
    ValueError
        If ``file_lengths`` is given but its length differs from that
        of ``arrow_files``.

    Notes
    -----
    `__iter__` mutates the cursor; multiple concurrent iterators on
    the same instance would interfere. In multi-worker DataLoader
    setups each worker receives its own copy (via fork or pickle), so
    concurrent cursors aren't an issue in practice.
    """

    def __init__(
        self,
        arrow_files: List[str],
        file_lengths: Optional[List[int]] = None,
    ):
        self.arrow_files: List[str] = list(arrow_files)
        if file_lengths is None:
            # Slow fallback — open each file to read its length.
            self.file_lengths = [len(Dataset.from_file(f)) for f in self.arrow_files]
        else:
            if len(file_lengths) != len(self.arrow_files):
                raise ValueError(
                    f"file_lengths length ({len(file_lengths)}) does not match "
                    f"arrow_files length ({len(self.arrow_files)})"
                )
            self.file_lengths = list(file_lengths)

        # Current iteration order (after any shuffle).
        self._seed: Optional[int] = None
        self._order_files: List[str] = self.arrow_files
        self._order_lengths: List[int] = self.file_lengths

        # Flat cursor — index of the NEXT example to yield.
        self._position: int = 0

        # Lazy schema cache (read from the first Arrow file).
        self._column_names: Optional[List[str]] = None
        self._features = None

    # ----- Backend interface -----

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """
        Walk Arrow files from the current cursor to end, yielding
        examples and updating ``_position`` as each one is emitted.

        ``_position`` is incremented BEFORE yield so callers that read
        it (e.g. the wrapper's `_iter_window` in check-then-consume
        mode) see "index of the next example" semantics consistently
        with InMemoryBackend / RemoteBackend.
        """
        # Flat index of the first example of the file being considered.
        cumul = 0
        for path, file_len in zip(self._order_files, self._order_lengths):
            file_end = cumul + file_len
            if file_end <= self._position:
                # Cursor is past this file entirely; skip without opening it.
                cumul = file_end
                continue
            # Only the file containing the cursor has a non-zero offset.
            local_start = max(0, self._position - cumul)
            ds = Dataset.from_file(path)
            if local_start > 0:
                ds = ds.select(range(local_start, file_len))
            for example in ds:
                self._position += 1
                yield example
            cumul = file_end

    def __len__(self) -> int:
        """Total example count, summed over the per-file lengths."""
        return sum(self.file_lengths)

    def shuffle(self, seed: Optional[int] = None) -> "ArrowBackend":
        """
        Return a new backend with files re-permuted under ``seed``.
        Cursor is reset to 0. No example-level buffer — that lives in
        the wrapper.

        When ``seed`` is ``None`` a random seed is drawn and stored on
        the new instance, so it is captured by `state_dict`.
        """
        if seed is None:
            seed = random.randint(0, 2**31 - 1)
        # Bypass __init__ so no file is reopened to recompute lengths.
        new = ArrowBackend.__new__(ArrowBackend)
        new.arrow_files = self.arrow_files
        new.file_lengths = self.file_lengths
        new._seed = seed
        new._order_files, new._order_lengths = self._shuffled_order(
            self.arrow_files, self.file_lengths, seed
        )
        new._position = 0
        new._column_names = self._column_names
        new._features = self._features
        return new

    def seek(self, position: int) -> "ArrowBackend":
        """
        Return a new backend with the cursor at ``position``. Past-the-end
        positions are clamped to the end (next iteration yields nothing).

        Raises
        ------
        ValueError
            If ``position`` is negative.
        """
        if position < 0:
            raise ValueError(f"position must be non-negative, got {position}")
        # Bypass __init__ so no file is reopened to recompute lengths.
        new = ArrowBackend.__new__(ArrowBackend)
        new.arrow_files = self.arrow_files
        new.file_lengths = self.file_lengths
        new._seed = self._seed
        new._order_files = self._order_files
        new._order_lengths = self._order_lengths
        new._position = min(position, len(self))
        new._column_names = self._column_names
        new._features = self._features
        return new

    def position(self) -> int:
        """Flat index of the next example `__iter__` would yield."""
        return self._position

    # ----- Optional metadata -----

    @property
    def column_names(self) -> List[str]:
        """Column names read lazily from the first Arrow file (``[]`` if no files)."""
        if self._column_names is None:
            if not self.arrow_files:
                return []
            self._column_names = Dataset.from_file(self.arrow_files[0]).column_names
        return self._column_names

    @property
    def features(self):
        """Feature schema read lazily from the first Arrow file (``None`` if no files)."""
        if self._features is None:
            if not self.arrow_files:
                return None
            self._features = Dataset.from_file(self.arrow_files[0]).features
        return self._features

    @property
    def n_shards(self) -> int:
        """Number of Arrow files; each file counts as one shard."""
        return len(self.arrow_files)

    # ----- Optional checkpoint protocol (used by the wrapper) -----

    def state_dict(self) -> Dict[str, Any]:
        """
        Capture cursor + order seed + dataset-identity fingerprint.

        The wrapper picks this up via the optional-backend-state_dict
        path so a checkpoint round-trip can detect "different files
        behind the same handle" early.
        """
        return {
            "version": 1,
            "fingerprint": self._fingerprint(),
            "num_files": len(self.arrow_files),
            "seed": self._seed,
            "position": self._position,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """
        Restore shuffle order and cursor from a `state_dict` payload.

        The saved position is validated the same way `seek` validates
        its argument: negative values are rejected and past-the-end
        values are clamped to the dataset length. Validation happens
        before any attribute is modified, so a rejected state leaves
        the backend untouched.

        Raises
        ------
        ValueError
            If the state version is unknown, the fingerprint or file
            count does not match this backend, or the saved position
            is negative.
        """
        if state.get("version") != 1:
            raise ValueError(
                f"Unknown ArrowBackend state version: {state.get('version')!r}"
            )
        saved_fp = state.get("fingerprint")
        if saved_fp is not None and saved_fp != self._fingerprint():
            raise ValueError(
                "Dataset fingerprint mismatch — checkpoint refers to a "
                "different set of Arrow files."
            )
        saved_n = state.get("num_files")
        if saved_n is not None and saved_n != len(self.arrow_files):
            raise ValueError(
                f"Number of files mismatch: checkpoint has {saved_n}, "
                f"backend has {len(self.arrow_files)}."
            )
        position = int(state.get("position", 0))
        if position < 0:
            raise ValueError(f"position must be non-negative, got {position}")

        self._seed = state.get("seed")
        if self._seed is not None:
            # Rebuild the permutation deterministically from the saved seed.
            self._order_files, self._order_lengths = self._shuffled_order(
                self.arrow_files, self.file_lengths, self._seed
            )
        else:
            self._order_files = self.arrow_files
            self._order_lengths = self.file_lengths
        self._position = min(position, len(self))

    # ----- Helpers -----

    def _fingerprint(self) -> str:
        """SHA-256 hex digest of the newline-joined file paths."""
        return hashlib.sha256("\n".join(self.arrow_files).encode("utf-8")).hexdigest()

    @staticmethod
    def _shuffled_order(
        files: List[str],
        lengths: List[int],
        seed: int,
    ) -> Tuple[List[str], List[int]]:
        """
        Permute ``files`` and ``lengths`` together under ``seed``.

        Returns two new lists that stay parallel to each other. An
        empty input yields two empty lists.
        """
        rng = random.Random(seed)
        paired = list(zip(files, lengths))
        rng.shuffle(paired)
        # Unzip with comprehensions: `zip(*paired)` cannot be unpacked
        # into two names when `paired` is empty.
        return [path for path, _ in paired], [count for _, count in paired]

    def __repr__(self) -> str:
        return (
            f"ArrowBackend(files={len(self.arrow_files)}, "
            f"examples={len(self)}, position={self._position}, "
            f"seed={self._seed})"
        )

__iter__()

Walk Arrow files from the current cursor to end, yielding examples and updating _position as each one is emitted.

_position is incremented BEFORE yield so callers that read it (e.g. the wrapper's _iter_window in check-then-consume mode) see "index of the next example" semantics consistently with InMemoryBackend / RemoteBackend.

Source code in src/forgather/ml/datasets/arrow_backend.py
def __iter__(self) -> Iterator[Dict[str, Any]]:
    """
    Stream examples from the cursor to the end of the dataset.

    Files are visited in the current order. Files that end at or
    before the cursor are skipped without being opened; the file
    containing the cursor is trimmed so iteration resumes mid-file.

    ``_position`` is bumped BEFORE each yield, so a caller inspecting
    it between examples always sees the index of the next example.
    """
    shard_start = 0
    for shard_path, shard_len in zip(self._order_files, self._order_lengths):
        shard_stop = shard_start + shard_len
        if self._position < shard_stop:
            # Offset into this shard; non-zero only for the shard
            # that contains the cursor.
            skip = max(0, self._position - shard_start)
            shard = Dataset.from_file(shard_path)
            if skip > 0:
                shard = shard.select(range(skip, shard_len))
            for row in shard:
                self._position += 1
                yield row
        shard_start = shard_stop

shuffle(seed=None)

Return a new backend with files re-permuted under seed. Cursor is reset to 0. No example-level buffer — that lives in the wrapper.

Source code in src/forgather/ml/datasets/arrow_backend.py
def shuffle(self, seed: Optional[int] = None) -> "ArrowBackend":
    """
    Build a fresh backend whose file order is permuted by ``seed``.

    A random seed is drawn when ``seed`` is ``None`` and recorded on
    the returned instance. The cursor of the new backend starts at 0.
    Example-level buffering is the wrapper's job, not this method's.
    """
    effective_seed = random.randint(0, 2**31 - 1) if seed is None else seed
    # Allocate without running __init__; all state is copied below.
    fresh = ArrowBackend.__new__(ArrowBackend)
    permuted_files, permuted_lengths = self._shuffled_order(
        self.arrow_files, self.file_lengths, effective_seed
    )
    # Carry over the dataset description and lazily cached schema.
    for attr in ("arrow_files", "file_lengths", "_column_names", "_features"):
        setattr(fresh, attr, getattr(self, attr))
    fresh._seed = effective_seed
    fresh._order_files = permuted_files
    fresh._order_lengths = permuted_lengths
    fresh._position = 0
    return fresh

seek(position)

Return a new backend with the cursor at position. Past-the-end positions are clamped to the end (next iteration yields nothing).

Source code in src/forgather/ml/datasets/arrow_backend.py
def seek(self, position: int) -> "ArrowBackend":
    """
    Build a copy of this backend with its cursor moved to ``position``.

    A position beyond the end is clamped to the dataset length, so the
    next iteration yields nothing. A negative position raises
    ``ValueError``.
    """
    if position < 0:
        raise ValueError(f"position must be non-negative, got {position}")
    # Allocate without running __init__; all state is copied below.
    moved = ArrowBackend.__new__(ArrowBackend)
    inherited = (
        "arrow_files",
        "file_lengths",
        "_seed",
        "_order_files",
        "_order_lengths",
        "_column_names",
        "_features",
    )
    for attr in inherited:
        setattr(moved, attr, getattr(self, attr))
    moved._position = min(position, len(self))
    return moved

state_dict()

Capture cursor + order seed + dataset-identity fingerprint.

The wrapper picks this up via the optional-backend-state_dict path so a checkpoint round-trip can detect "different files behind the same handle" early.

Source code in src/forgather/ml/datasets/arrow_backend.py
def state_dict(self) -> Dict[str, Any]:
    """
    Snapshot the backend for checkpointing.

    The snapshot holds a format version, a fingerprint of the file
    list plus the file count (so a restore can notice that different
    files sit behind the same handle), the shuffle seed, and the
    current cursor.
    """
    snapshot: Dict[str, Any] = {"version": 1}
    snapshot["fingerprint"] = self._fingerprint()
    snapshot["num_files"] = len(self.arrow_files)
    snapshot["seed"] = self._seed
    snapshot["position"] = self._position
    return snapshot

forgather.ml.datasets.remote_backend.RemoteBackend

Bases: IterableDatasetBackend

Network-proxy backend.

Parameters:

Name Type Description Default
url str

Base URL of the dataset server, e.g. "http://host:8766".

required
handle str

Server-side identifier for the registered backend to consume.

required
seed int

Shuffle seed; None means no shuffle requested.

None
position int

Initial flat example index. Default 0.

0
timeout float

Per-request HTTP timeout (seconds). Default 60.

60.0
token str

Explicit bearer token. If omitted, the constructor consults $FORGATHER_DATASET_SERVER_TOKEN and (for localhost URLs) <forgather_config_dir>/dataset_server/<port>.token.

None
Source code in src/forgather/ml/datasets/remote_backend.py
class RemoteBackend(IterableDatasetBackend):
    """
    Network-proxy backend.

    Parameters
    ----------
    url : str
        Base URL of the dataset server, e.g. ``"http://host:8766"``.
    handle : str
        Server-side identifier for the registered backend to consume.
    seed : int, optional
        Shuffle seed; ``None`` means no shuffle requested.
    position : int, optional
        Initial flat example index. Default ``0``.
    timeout : float, optional
        Per-request HTTP timeout (seconds). Default ``60``.
    token : str, optional
        Explicit bearer token. If omitted, the constructor consults
        ``$FORGATHER_DATASET_SERVER_TOKEN`` and (for localhost URLs)
        ``<forgather_config_dir>/dataset_server/<port>.token``.
    column_names : list of str, optional
        Known column names of the dataset. When omitted they are
        fetched lazily on first access of `column_names`.

    Raises
    ------
    ValueError
        If ``position`` is negative.
    """

    def __init__(
        self,
        url: str,
        handle: str,
        seed: Optional[int] = None,
        position: int = 0,
        timeout: float = 60.0,
        token: Optional[str] = None,
        column_names: Optional[list[str]] = None,
    ):
        if position < 0:
            raise ValueError(f"position must be non-negative, got {position}")
        self._url = url.rstrip("/")
        self._handle = handle
        self._seed = seed
        self._position = position
        self._timeout = timeout
        # Resolved once at construction; if you change tokens, build a
        # new client. Most callers won't notice.
        self._token = resolve_auth_token(self._url, token)
        self._cached_len: Optional[int] = None
        # Schema cache. The loader passes column_names from the
        # /v1/load response so the client can answer column-aware
        # APIs (e.g. preprocess_dataset's remove_columns) without an
        # extra round trip; if not supplied here, we fetch it lazily
        # from /v1/datasets/{handle} on first access.
        self._column_names: Optional[list[str]] = (
            list(column_names) if column_names is not None else None
        )

    # ----- Backend interface -----

    def __iter__(self) -> Iterator[dict]:
        """
        Open a streaming /iter request from the current position and
        yield decoded examples. Updates ``self._position`` as each
        example arrives so callers can capture progress mid-stream.

        Network and 5xx errors at any point (initial open or mid-stream)
        are translated to :class:`DatasetServerUnreachable`. 4xx errors
        (token, bad request) propagate unchanged.
        """
        params: dict[str, str] = {"position": str(self._position)}
        if self._seed is not None:
            params["seed"] = str(self._seed)
        url = f"{self._url}/v1/datasets/{self._handle}/iter?{urlencode(params)}"
        req = Request(url, method="GET", headers=self._headers())
        try:
            resp = _dataset_urlopen(req, timeout=self._timeout, url=self._url)
        except Exception as exc:
            raise _translate_request_error(exc) from exc
        try:
            for raw in resp:
                line = raw.rstrip(b"\n")
                if not line:
                    # Blank lines carry no example; skip them.
                    continue
                example = _from_jsonable(json.loads(line.decode("utf-8")))
                self._position += 1
                yield example
        except Exception as exc:
            # Mid-stream socket drops surface as URLError / ConnectionError
            # from the response iterator; translate so the wrapper can
            # retry from the current (updated) position.
            translated = _translate_request_error(exc)
            if translated is exc:
                raise
            raise translated from exc
        finally:
            # Best-effort close; a failure here must not mask the
            # outcome of the iteration itself.
            try:
                resp.close()
            except Exception:
                pass

    def __len__(self) -> int:
        """
        Total example count, fetched from the server's ``/length``
        endpoint on first use and cached afterwards.
        """
        if self._cached_len is None:
            url = f"{self._url}/v1/datasets/{self._handle}/length"
            req = Request(url, method="GET", headers=self._headers())
            try:
                with _dataset_urlopen(
                    req, timeout=self._timeout, url=self._url
                ) as resp:
                    payload = json.loads(resp.read().decode("utf-8"))
            except Exception as exc:
                raise _translate_request_error(exc) from exc
            self._cached_len = int(payload["length"])
        return self._cached_len

    def shuffle(self, seed: Optional[int] = None) -> "RemoteBackend":
        """
        Return a new client with the new seed; position resets to 0.

        No RPC is issued — the seed travels with the next ``/iter``
        request. The cached length and cached column names are
        preserved (shuffling doesn't change the underlying example
        count or schema).

        When ``seed`` is ``None`` a random seed is drawn and stored on
        the new client. A client whose seed is ``None`` sends no seed
        with ``/iter`` and therefore requests no shuffle, which would
        contradict the backend contract for ``shuffle``.
        """
        if seed is None:
            import random

            seed = random.randint(0, 2**31 - 1)
        new = RemoteBackend(
            self._url,
            self._handle,
            seed=seed,
            position=0,
            timeout=self._timeout,
            token=self._token,
            column_names=self._column_names,
        )
        new._cached_len = self._cached_len
        return new

    def seek(self, position: int) -> "RemoteBackend":
        """
        Return a new client positioned at the given flat example index.

        No RPC is issued — the position travels with the next
        ``/iter`` request. The cached length and cached column names
        are carried over to the new client.

        Raises
        ------
        ValueError
            If ``position`` is negative.
        """
        if position < 0:
            raise ValueError(f"position must be non-negative, got {position}")
        new = RemoteBackend(
            self._url,
            self._handle,
            seed=self._seed,
            position=position,
            timeout=self._timeout,
            token=self._token,
            column_names=self._column_names,
        )
        new._cached_len = self._cached_len
        return new

    def position(self) -> int:
        """Flat index of the next example `__iter__` would request."""
        return self._position

    # ----- Optional metadata -----

    @property
    def n_shards(self) -> int:
        # The remote layer doesn't expose physical sharding info; the
        # server may have any number of files behind the handle.
        return 1

    @property
    def column_names(self) -> Optional[list[str]]:
        """Column names of the underlying dataset.

        Populated either from the `/v1/load` response (most common —
        the loader passes them through) or by a lazy GET to
        `/v1/datasets/{handle}` on first access. Returns ``None`` if
        the server can't determine them.
        """
        if self._column_names is not None:
            return self._column_names
        url = f"{self._url}/v1/datasets/{self._handle}"
        req = Request(url, method="GET", headers=self._headers())
        try:
            with _dataset_urlopen(req, timeout=self._timeout, url=self._url) as resp:
                payload = json.loads(resp.read().decode("utf-8"))
        except Exception as exc:
            # Schema lookup is best-effort: report "unknown" rather
            # than failing the caller.
            logger.debug("column_names lookup failed: %s", exc)
            return None
        cols = payload.get("column_names")
        if cols is not None:
            self._column_names = list(cols)
        return self._column_names

    # ----- helpers -----

    def _headers(self) -> dict[str, str]:
        """Authorization header for requests, or ``{}`` when no token is set."""
        if not self._token:
            return {}
        return {"Authorization": f"Bearer {self._token}"}

    def __repr__(self) -> str:
        return (
            f"RemoteBackend(url={self._url!r}, "
            f"handle={self._handle!r}, seed={self._seed}, "
            f"position={self._position})"
        )

column_names property

Column names of the underlying dataset.

Populated either from the /v1/load response (most common — the loader passes them through) or by a lazy GET to /v1/datasets/{handle} on first access. Returns None if the server can't determine them.

__iter__()

Open a streaming /iter request from the current position and yield decoded examples. Updates self._position as each example arrives so callers can capture progress mid-stream.

Network and 5xx errors at any point (initial open or mid-stream) are translated to DatasetServerUnreachable. 4xx errors (token, bad request) propagate unchanged.

Source code in src/forgather/ml/datasets/remote_backend.py
def __iter__(self) -> Iterator[dict]:
    """
    Stream examples from the server, starting at the current position.

    Issues one ``/iter`` request carrying the position (and the seed,
    when one is set), then decodes each non-empty response line into an
    example. ``self._position`` advances just before each example is
    yielded, so progress can be captured mid-stream.

    Failures while opening the request or while reading the stream go
    through ``_translate_request_error``; network and 5xx errors
    become :class:`DatasetServerUnreachable`, while 4xx errors (token,
    bad request) propagate unchanged.
    """
    query: dict[str, str] = {"position": str(self._position)}
    if self._seed is not None:
        query["seed"] = str(self._seed)
    endpoint = f"{self._url}/v1/datasets/{self._handle}/iter?{urlencode(query)}"
    request = Request(endpoint, method="GET", headers=self._headers())
    try:
        stream = _dataset_urlopen(request, timeout=self._timeout, url=self._url)
    except Exception as exc:
        raise _translate_request_error(exc) from exc
    try:
        for chunk in stream:
            body = chunk.rstrip(b"\n")
            if body:
                record = _from_jsonable(json.loads(body.decode("utf-8")))
                self._position += 1
                yield record
    except Exception as exc:
        # A dropped socket shows up here as URLError / ConnectionError
        # raised by the response iterator. Translating it lets the
        # wrapper retry from the already-advanced position.
        mapped = _translate_request_error(exc)
        if mapped is not exc:
            raise mapped from exc
        raise
    finally:
        # Closing is best-effort and must not mask the real outcome.
        try:
            stream.close()
        except Exception:
            pass

shuffle(seed=None)

Return a new client with the new seed; position resets to 0.

No RPC is issued — the seed travels with the next /iter request. The cached length is preserved (shuffling doesn't change the underlying example count).

Source code in src/forgather/ml/datasets/remote_backend.py
def shuffle(self, seed: Optional[int] = None) -> "RemoteBackend":
    """
    Return a new client with the new seed; position resets to 0.

    No RPC is issued — the seed travels with the next ``/iter``
    request. The cached length and cached column names are preserved
    (shuffling doesn't change the underlying example count or schema).

    When ``seed`` is ``None`` a random seed is drawn and stored on the
    new client. A client whose seed is ``None`` sends no seed with
    ``/iter`` and therefore requests no shuffle, which would contradict
    the backend contract for ``shuffle``.
    """
    if seed is None:
        import random

        seed = random.randint(0, 2**31 - 1)
    new = RemoteBackend(
        self._url,
        self._handle,
        seed=seed,
        position=0,
        timeout=self._timeout,
        token=self._token,
        # Keep the schema cache so column-aware callers don't pay for
        # another round trip after shuffling.
        column_names=self._column_names,
    )
    new._cached_len = self._cached_len
    return new

seek(position)

Return a new client positioned at the given flat example index.

No RPC is issued — the position travels with the next /iter request.

Source code in src/forgather/ml/datasets/remote_backend.py
def seek(self, position: int) -> "RemoteBackend":
    """
    Return a new client positioned at the given flat example index.

    No RPC is issued — the position travels with the next
    ``/iter`` request. The cached length and cached column names are
    carried over to the new client.

    Raises
    ------
    ValueError
        If ``position`` is negative.
    """
    if position < 0:
        raise ValueError(f"position must be non-negative, got {position}")
    new = RemoteBackend(
        self._url,
        self._handle,
        seed=self._seed,
        position=position,
        timeout=self._timeout,
        token=self._token,
        # Keep the schema cache so column-aware callers don't pay for
        # another round trip after seeking.
        column_names=self._column_names,
    )
    new._cached_len = self._cached_len
    return new

Interleaved Datasets

forgather.ml.datasets.interleaved.InterleavedDataset

Bases: IterableDataset

An iterable dataset that interleaves examples from multiple child datasets.

Works with any iterable dataset that supports the stateful checkpoint protocol (state_dict / load_state_dict), including ComposableIterableDataset. Designed for multi-dataset pre-training where examples from several corpora need to be mixed in a single training loop.

Parameters:

Name Type Description Default
datasets list

Child datasets to interleave. Must be non-empty. Each element can be any iterable; checkpointing is available for elements that implement state_dict / load_state_dict.

required
probabilities list of float or callable

Controls which child dataset is sampled at each step:

  • None (default) — round-robin: datasets are visited in order, cycling back to the first after the last.
  • list of float — static per-dataset weights. Values are normalised automatically; all must be non-negative and their sum must be positive.
  • callable — dynamic weight function called at each step with signature (step, datasets, examples_per_dataset, exhausted) -> list of float. See balance_remaining_examples for an example implementation.
None
seed int

Random seed for reproducible probabilistic sampling. Ignored when probabilities is None (round-robin).

None
stopping_strategy (first_exhausted, all_exhausted)

When to stop iteration:

  • "first_exhausted" (default) — stop as soon as any child dataset is exhausted.
  • "all_exhausted" — continue until every child dataset is exhausted, oversampling shorter datasets.
"first_exhausted"

Raises:

Type Description
ValueError

If datasets is empty, probabilities fail validation, or an unsupported stopping_strategy is given.

Examples:

>>> ds1 = fast_load_iterable_dataset("corpus_a", split="train")
>>> ds2 = fast_load_iterable_dataset("corpus_b", split="train")
>>> combined = InterleavedDataset([ds1, ds2], probabilities=[0.7, 0.3], seed=42)
>>> for example in combined:
...     pass
Source code in src/forgather/ml/datasets/interleaved.py
class InterleavedDataset(TorchIterableDataset):
    """
    An iterable dataset that interleaves examples from multiple child datasets.

    Any iterable can be used as a child. Children that implement the stateful
    checkpoint protocol (``state_dict`` / ``load_state_dict``), including
    `ComposableIterableDataset`, are checkpointed together with the
    interleaving position. Designed for multi-dataset pre-training where
    examples from several corpora need to be mixed in a single training loop.

    Parameters
    ----------
    datasets : list
        Child datasets to interleave. Must be non-empty. Each element can be
        any iterable; checkpointing is available for elements that implement
        ``state_dict`` / ``load_state_dict``.
    probabilities : list of float or callable, optional
        Controls which child dataset is sampled at each step:

        - ``None`` (default) — round-robin: datasets are visited in order,
          cycling back to the first after the last.
        - ``list of float`` — static per-dataset weights. Values are
          normalised automatically; all must be non-negative and their sum
          must be positive.
        - ``callable`` — dynamic weight function called at each step with
          signature ``(step, datasets, examples_per_dataset, exhausted)
          -> list of float``. ``step`` is the number of examples yielded so
          far. See `balance_remaining_examples` for an example
          implementation.
    seed : int, optional
        Random seed for reproducible probabilistic sampling. Ignored when
        ``probabilities`` is ``None`` (round-robin).
    stopping_strategy : {"first_exhausted", "all_exhausted"}, optional
        When to stop iteration:

        - ``"first_exhausted"`` (default) — stop as soon as any child dataset
          is exhausted.
        - ``"all_exhausted"`` — continue until every child dataset is
          exhausted. A child that runs out is skipped for the rest of the
          pass; it is not restarted.

    Raises
    ------
    ValueError
        If ``datasets`` is empty, probabilities fail validation, or an
        unsupported ``stopping_strategy`` is given.

    Examples
    --------
    >>> ds1 = fast_load_iterable_dataset("corpus_a", split="train")
    >>> ds2 = fast_load_iterable_dataset("corpus_b", split="train")
    >>> combined = InterleavedDataset([ds1, ds2], probabilities=[0.7, 0.3], seed=42)
    >>> for example in combined:
    ...     pass
    """

    def __init__(
        self,
        datasets: List,
        probabilities: Optional[Union[List[float], Callable]] = None,
        seed: Optional[int] = None,
        stopping_strategy: str = "first_exhausted",
    ):
        if not datasets:
            raise ValueError("Cannot interleave empty list of datasets")

        self.datasets = datasets
        self.seed = seed
        self.stopping_strategy = stopping_strategy

        # Probabilities are either a static list (validated and normalised
        # here) or a callable that is evaluated at every step in __iter__.
        self._probabilities_callable = callable(probabilities)
        if self._probabilities_callable:
            self._probabilities_fn = probabilities
            self.probabilities = None  # Computed dynamically
        else:
            self._probabilities_fn = None
            self.probabilities = probabilities

            if probabilities is not None:
                if len(probabilities) != len(datasets):
                    raise ValueError(
                        f"Probabilities length ({len(probabilities)}) must match datasets length ({len(datasets)})"
                    )
                if not all(p >= 0 for p in probabilities):
                    raise ValueError("All probabilities must be non-negative")
                prob_sum = sum(probabilities)
                if prob_sum == 0:
                    raise ValueError("At least one probability must be > 0")
                self.probabilities = [p / prob_sum for p in probabilities]

        if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
            raise ValueError(
                f"Unsupported stopping_strategy: {stopping_strategy}. "
                f"Use 'first_exhausted' or 'all_exhausted'"
            )

        # Checkpoint state, refreshed by __iter__ on every yield
        self._current_dataset_index = 0  # Child most recently sampled
        self._current_example_count = 0  # Total examples yielded
        self._datasets_exhausted = [False] * len(datasets)

    def __repr__(self):
        parts = ["InterleavedDataset(\n"]
        parts.extend("  " + repr(ds) + ",\n" for ds in self.datasets)
        parts.append("  probabilities=" + repr(self.probabilities) + ",\n")
        parts.append("  seed=" + repr(self.seed) + ",\n")
        parts.append("  stopping_strategy=" + repr(self.stopping_strategy) + ",\n")
        parts.append("  _probabilities_fn=" + repr(self._probabilities_fn) + "\n)\n")
        return "".join(parts)

    def __iter__(self):
        """
        Yield interleaved examples from all child datasets.

        Selects which child to draw from at each step using the configured
        sampling strategy (round-robin or probabilistic) and stops according
        to ``stopping_strategy``.

        If `load_state_dict` restored per-dataset example counts before this
        call, the counters continue from the checkpoint: the total example
        count (also passed as ``step`` to a dynamic probability function)
        resumes from the restored value, and round-robin continues with the
        child after the one that was sampled last. Positioning inside each
        child is the child's own responsibility; this method only calls
        ``iter()`` on it.

        Yields
        ------
        dict
            One example per step from whichever child dataset was selected.
        """
        import random

        num_datasets = len(self.datasets)
        iterators = [iter(dataset) for dataset in self.datasets]
        exhausted = [False] * num_datasets

        # Per-child and total yield counters, plus the round-robin pointer.
        examples_per_dataset = [0] * num_datasets
        examples_yielded = 0
        current_idx = 0

        # One-shot resume: load_state_dict leaves this marker behind when the
        # checkpoint carried per-dataset counts.
        if hasattr(self, "_restored_examples_per_dataset"):
            examples_per_dataset = self._restored_examples_per_dataset.copy()
            delattr(self, "_restored_examples_per_dataset")
            # Continue counting from the checkpoint instead of from zero so
            # that later checkpoints keep the true running total.
            examples_yielded = self._current_example_count
            if examples_yielded > 0:
                # Round-robin picks up after the child that was sampled last.
                current_idx = (self._current_dataset_index + 1) % num_datasets

        use_probabilities = (
            self.probabilities is not None or self._probabilities_callable
        )
        # NOTE(review): the RNG is re-created from ``seed`` on every call,
        # including after a restore, so the sequence of random choices starts
        # over rather than continuing from the checkpoint.
        rng = random.Random(self.seed) if use_probabilities else None

        while True:
            if self.stopping_strategy == "first_exhausted" and any(exhausted):
                break

            if use_probabilities:
                # Sample only among children that still have data
                available_indices = [
                    i for i, done in enumerate(exhausted) if not done
                ]
                if not available_indices:
                    break

                if self._probabilities_callable:
                    current_probs = self._probabilities_fn(
                        examples_yielded, self.datasets, examples_per_dataset, exhausted
                    )
                    if len(current_probs) != len(self.datasets):
                        raise ValueError(
                            f"Probability function returned {len(current_probs)} values, "
                            f"expected {len(self.datasets)}"
                        )
                else:
                    current_probs = self.probabilities

                available_probs = [current_probs[i] for i in available_indices]
                prob_sum = sum(available_probs)
                if prob_sum == 0:
                    # Every remaining child has zero weight: nothing to sample
                    break
                normalized_probs = [p / prob_sum for p in available_probs]

                chosen_idx = rng.choices(available_indices, weights=normalized_probs)[0]
            else:
                # Round-robin: advance past exhausted children, at most one lap
                attempts = 0
                while exhausted[current_idx] and attempts < num_datasets:
                    current_idx = (current_idx + 1) % num_datasets
                    attempts += 1

                if exhausted[current_idx]:
                    # A full lap found no live child
                    break

                chosen_idx = current_idx
                current_idx = (current_idx + 1) % num_datasets

            try:
                example = next(iterators[chosen_idx])
            except StopIteration:
                exhausted[chosen_idx] = True
                self._datasets_exhausted[chosen_idx] = True

                if self.stopping_strategy == "all_exhausted" and all(exhausted):
                    break
                # For first_exhausted the check at the top of the loop stops us
                continue

            examples_yielded += 1
            examples_per_dataset[chosen_idx] += 1

            # Publish the position before yielding so that a state_dict()
            # taken while the generator is suspended reflects this example.
            self._current_dataset_index = chosen_idx
            self._current_example_count = examples_yielded
            self._examples_per_dataset_checkpoint = examples_per_dataset.copy()

            yield example

    def __len__(self) -> int:
        """
        Return an estimate of the total number of examples that will be yielded.

        - ``"first_exhausted"`` with round-robin — ``min(child_lengths) * num_datasets``.
        - ``"first_exhausted"`` with static or dynamic probabilities —
          ``sum(child_lengths)`` (an upper bound used as an approximation).
        - ``"all_exhausted"`` — ``sum(child_lengths)`` regardless of sampling mode.

        Every child dataset must support ``len()``.

        Returns
        -------
        int
            Estimated total example count.
        """
        dataset_lengths = [len(ds) for ds in self.datasets]

        # ``self.probabilities`` is also None for callable probabilities, so
        # both attributes are needed to recognise round-robin.
        round_robin = self.probabilities is None and not self._probabilities_callable

        if self.stopping_strategy == "first_exhausted" and round_robin:
            return min(dataset_lengths) * len(self.datasets)

        # TODO: Improve on the probabilistic "first_exhausted" estimate.
        return sum(dataset_lengths)

    @property
    def column_names(self) -> List[str]:
        """Get column names from first dataset."""
        if not self.datasets:
            return []
        if hasattr(self.datasets[0], "column_names"):
            return self.datasets[0].column_names
        return []

    @property
    def features(self):
        """Get features from first dataset."""
        if not self.datasets:
            return None
        if hasattr(self.datasets[0], "features"):
            return self.datasets[0].features
        return None

    @property
    def n_shards(self) -> int:
        """Total number of shards across all datasets."""
        total = 0
        for ds in self.datasets:
            if hasattr(ds, "n_shards"):
                total += ds.n_shards
            else:
                total += 1  # Assume 1 shard if not specified
        return total

    def state_dict(self) -> Dict[str, Any]:
        """
        Serialize the interleaving position and all child dataset states.

        Returns
        -------
        dict
            Dictionary with the following keys:

            ``current_dataset_index``
                Index of the child dataset that was most recently sampled.
            ``current_example_count``
                Total examples yielded so far across all children.
            ``datasets_exhausted``
                Boolean list indicating which children are exhausted.
            ``probabilities``
                Normalised static probabilities (``None`` if round-robin or
                dynamic).
            ``seed``
                Random seed.
            ``stopping_strategy``
                Configured stopping strategy string.
            ``child_states``
                List of per-child state dicts (``None`` for children that do
                not implement ``state_dict``).
            ``examples_per_dataset``
                Per-child example counts at the time of the last yield
                (present only once this instance has yielded an example).
        """
        state = {
            "current_dataset_index": self._current_dataset_index,
            "current_example_count": self._current_example_count,
            "datasets_exhausted": self._datasets_exhausted.copy(),
            "probabilities": self.probabilities,
            "seed": self.seed,
            "stopping_strategy": self.stopping_strategy,
            "child_states": [],
        }

        # Only set by __iter__ after the first yield
        if hasattr(self, "_examples_per_dataset_checkpoint"):
            state["examples_per_dataset"] = self._examples_per_dataset_checkpoint.copy()

        # One entry per child, None where the child has no state_dict()
        state["child_states"].extend(
            dataset.state_dict() if hasattr(dataset, "state_dict") else None
            for dataset in self.datasets
        )

        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Restore the interleaving position and all child dataset states.

        The next call to ``__iter__`` continues counting from the restored
        position. Child datasets that implement ``load_state_dict`` receive
        their saved state; children without it (or with a ``None`` saved
        state) are left untouched.

        Parameters
        ----------
        state_dict : dict
            Dictionary previously returned by `state_dict`.

        Raises
        ------
        KeyError
            If ``current_dataset_index`` or ``current_example_count`` is
            missing from ``state_dict``.
        """
        self._current_dataset_index = state_dict["current_dataset_index"]
        self._current_example_count = state_dict["current_example_count"]
        # Copy: __iter__ mutates this list in place and must not write back
        # into the caller's checkpoint dictionary.
        self._datasets_exhausted = list(
            state_dict.get("datasets_exhausted", [False] * len(self.datasets))
        )

        # Consumed (and removed) by the next __iter__ call
        if "examples_per_dataset" in state_dict:
            self._restored_examples_per_dataset = state_dict[
                "examples_per_dataset"
            ].copy()

        child_states = state_dict.get("child_states", [])
        for dataset, child_state in zip(self.datasets, child_states):
            if child_state is not None and hasattr(dataset, "load_state_dict"):
                dataset.load_state_dict(child_state)

column_names property

Get column names from first dataset.

features property

Get features from first dataset.

n_shards property

Total number of shards across all datasets.

__iter__()

Yield interleaved examples from all child datasets.

Selects which child to draw from at each step using the configured sampling strategy (round-robin or probabilistic). Stops according to stopping_strategy. If load_state_dict was called before iteration, the saved per-dataset example counts are restored; child datasets that received their own saved state via load_state_dict are expected to resume from their checkpointed positions when iterated.

Yields:

Type Description
dict

One example per step from whichever child dataset was selected.

Source code in src/forgather/ml/datasets/interleaved.py
def __iter__(self):
    """
    Yield interleaved examples from all child datasets.

    Selects which child to draw from at each step using the configured
    sampling strategy (round-robin or probabilistic) and stops according
    to ``stopping_strategy``.

    If `load_state_dict` restored per-dataset example counts before this
    call, the counters continue from the checkpoint: the total example
    count (also passed as ``step`` to a dynamic probability function)
    resumes from the restored value, and round-robin continues with the
    child after the one that was sampled last. Positioning inside each
    child is the child's own responsibility; this method only calls
    ``iter()`` on it.

    Yields
    ------
    dict
        One example per step from whichever child dataset was selected.
    """
    import random

    num_datasets = len(self.datasets)
    iterators = [iter(dataset) for dataset in self.datasets]
    exhausted = [False] * num_datasets

    # Per-child and total yield counters, plus the round-robin pointer.
    examples_per_dataset = [0] * num_datasets
    examples_yielded = 0
    current_idx = 0

    # One-shot resume: load_state_dict leaves this marker behind when the
    # checkpoint carried per-dataset counts.
    if hasattr(self, "_restored_examples_per_dataset"):
        examples_per_dataset = self._restored_examples_per_dataset.copy()
        delattr(self, "_restored_examples_per_dataset")
        # Continue counting from the checkpoint instead of from zero so
        # that later checkpoints keep the true running total.
        examples_yielded = self._current_example_count
        if examples_yielded > 0:
            # Round-robin picks up after the child that was sampled last.
            current_idx = (self._current_dataset_index + 1) % num_datasets

    use_probabilities = (
        self.probabilities is not None or self._probabilities_callable
    )
    # NOTE(review): the RNG is re-created from ``seed`` on every call,
    # including after a restore, so the sequence of random choices starts
    # over rather than continuing from the checkpoint.
    rng = random.Random(self.seed) if use_probabilities else None

    while True:
        if self.stopping_strategy == "first_exhausted" and any(exhausted):
            break

        if use_probabilities:
            # Sample only among children that still have data
            available_indices = [
                i for i, done in enumerate(exhausted) if not done
            ]
            if not available_indices:
                break

            if self._probabilities_callable:
                current_probs = self._probabilities_fn(
                    examples_yielded, self.datasets, examples_per_dataset, exhausted
                )
                if len(current_probs) != len(self.datasets):
                    raise ValueError(
                        f"Probability function returned {len(current_probs)} values, "
                        f"expected {len(self.datasets)}"
                    )
            else:
                current_probs = self.probabilities

            available_probs = [current_probs[i] for i in available_indices]
            prob_sum = sum(available_probs)
            if prob_sum == 0:
                # Every remaining child has zero weight: nothing to sample
                break
            normalized_probs = [p / prob_sum for p in available_probs]

            chosen_idx = rng.choices(available_indices, weights=normalized_probs)[0]
        else:
            # Round-robin: advance past exhausted children, at most one lap
            attempts = 0
            while exhausted[current_idx] and attempts < num_datasets:
                current_idx = (current_idx + 1) % num_datasets
                attempts += 1

            if exhausted[current_idx]:
                # A full lap found no live child
                break

            chosen_idx = current_idx
            current_idx = (current_idx + 1) % num_datasets

        try:
            example = next(iterators[chosen_idx])
        except StopIteration:
            exhausted[chosen_idx] = True
            self._datasets_exhausted[chosen_idx] = True

            if self.stopping_strategy == "all_exhausted" and all(exhausted):
                break
            # For first_exhausted the check at the top of the loop stops us
            continue

        examples_yielded += 1
        examples_per_dataset[chosen_idx] += 1

        # Publish the position before yielding so that a state_dict()
        # taken while the generator is suspended reflects this example.
        self._current_dataset_index = chosen_idx
        self._current_example_count = examples_yielded
        self._examples_per_dataset_checkpoint = examples_per_dataset.copy()

        yield example

__len__()

Return an estimate of the total number of examples that will be yielded.

The estimate depends on the stopping_strategy:

  • "first_exhausted" with round-robin — min(child_lengths) * num_datasets.
  • "first_exhausted" with probabilities — sum(child_lengths) (approximation; exact calculation is complex).
  • "all_exhausted" — sum(child_lengths) regardless of sampling mode.

Returns:

Type Description
int

Estimated total example count.

Source code in src/forgather/ml/datasets/interleaved.py
def __len__(self) -> int:
    """
    Return an estimate of the total number of examples that will be yielded.

    - ``"first_exhausted"`` with round-robin — ``min(child_lengths) * num_datasets``.
    - ``"first_exhausted"`` with static or dynamic probabilities —
      ``sum(child_lengths)`` (an upper bound used as an approximation).
    - ``"all_exhausted"`` — ``sum(child_lengths)`` regardless of sampling mode.

    Every child dataset must support ``len()``.

    Returns
    -------
    int
        Estimated total example count.
    """
    dataset_lengths = [len(ds) for ds in self.datasets]

    # ``self.probabilities`` is also None for callable probabilities, so
    # both attributes are needed to recognise round-robin.
    round_robin = self.probabilities is None and not self._probabilities_callable

    if self.stopping_strategy == "first_exhausted" and round_robin:
        return min(dataset_lengths) * len(self.datasets)

    # TODO: Improve on the probabilistic "first_exhausted" estimate.
    return sum(dataset_lengths)

state_dict()

Serialize the interleaving position and all child dataset states.

Returns:

Type Description
dict

Dictionary with the following keys:

current_dataset_index Index of the child dataset that was most recently sampled. current_example_count Total examples yielded so far across all children. datasets_exhausted Boolean list indicating which children are exhausted. probabilities Normalised static probabilities (None if round-robin or dynamic). seed Random seed. stopping_strategy Configured stopping strategy string. child_states List of per-child state dicts (None for children that do not implement state_dict). examples_per_dataset Per-child example counts at the time of the last yield (present only when available; required for dynamic probability functions).

Source code in src/forgather/ml/datasets/interleaved.py
def state_dict(self) -> Dict[str, Any]:
    """
    Capture the interleaving position together with every child's state.

    Returns
    -------
    dict
        Snapshot with these entries:

        ``current_dataset_index``
            Index of the child that was sampled most recently.
        ``current_example_count``
            Running total of examples yielded across all children.
        ``datasets_exhausted``
            Copy of the per-child exhaustion flags.
        ``probabilities``
            Normalised static probabilities (``None`` for round-robin or
            dynamic sampling).
        ``seed``
            Random seed.
        ``stopping_strategy``
            Configured stopping strategy string.
        ``child_states``
            One entry per child: its ``state_dict()`` result, or ``None``
            when the child has no ``state_dict`` attribute.
        ``examples_per_dataset``
            Copy of the per-child example counts recorded at the last
            yield; present only once such counts have been recorded.
    """
    snapshot = {
        "current_dataset_index": self._current_dataset_index,
        "current_example_count": self._current_example_count,
        "datasets_exhausted": self._datasets_exhausted.copy(),
        "probabilities": self.probabilities,
        "seed": self.seed,
        "stopping_strategy": self.stopping_strategy,
        "child_states": [],
    }

    # The per-child counters only exist after iteration has recorded them.
    if hasattr(self, "_examples_per_dataset_checkpoint"):
        snapshot["examples_per_dataset"] = self._examples_per_dataset_checkpoint.copy()

    # Children lacking state_dict() are represented by a None placeholder.
    snapshot["child_states"].extend(
        child.state_dict() if hasattr(child, "state_dict") else None
        for child in self.datasets
    )

    return snapshot

load_state_dict(state_dict)

Restore the interleaving position and all child dataset states.

After calling this method, the next iteration resumes from the saved position. Child datasets that implement load_state_dict are restored individually; others are left at their natural start position.

Parameters:

Name Type Description Default
state_dict dict

Dictionary previously returned by state_dict.

required
Source code in src/forgather/ml/datasets/interleaved.py
def load_state_dict(self, state_dict: Dict[str, Any]):
    """
    Restore the interleaving position and all child dataset states.

    Child datasets that implement ``load_state_dict`` receive their saved
    state; children without it (or with a ``None`` saved state) are left
    untouched. If the checkpoint carries ``examples_per_dataset``, a copy
    is stashed for the next ``__iter__`` call to pick up.

    Parameters
    ----------
    state_dict : dict
        Dictionary previously returned by `state_dict`.

    Raises
    ------
    KeyError
        If ``current_dataset_index`` or ``current_example_count`` is
        missing from ``state_dict``.
    """
    self._current_dataset_index = state_dict["current_dataset_index"]
    self._current_example_count = state_dict["current_example_count"]
    # Copy: the exhaustion flags are mutated in place during iteration and
    # must not write back into the caller's checkpoint dictionary.
    self._datasets_exhausted = list(
        state_dict.get("datasets_exhausted", [False] * len(self.datasets))
    )

    # Stored under a separate attribute so it is applied exactly once
    if "examples_per_dataset" in state_dict:
        self._restored_examples_per_dataset = state_dict[
            "examples_per_dataset"
        ].copy()

    # zip() stops at the shorter sequence, so surplus children or surplus
    # saved states are ignored.
    child_states = state_dict.get("child_states", [])
    for dataset, child_state in zip(self.datasets, child_states):
        if child_state is not None and hasattr(dataset, "load_state_dict"):
            dataset.load_state_dict(child_state)

Utilities

forgather.ml.datasets.iterable_with_length.IterableDatasetWithLength

Bases: IterableDataset

A thin wrapper that adds a known length to an iterable dataset.

PyTorch's IterableDataset does not require __len__, but trainers and data-loader utilities often query it to calculate epoch step counts. When a map-style Dataset is converted to an iterable form with to_iterable_dataset(), the length information is lost. This wrapper re-attaches it.

All attribute and method accesses that are not handled by this class are forwarded transparently to the wrapped dataset via __getattr__, including state_dict / load_state_dict for checkpointing, and HuggingFace Dataset attributes such as column_names and features.

Parameters:

Name Type Description Default
iterable_dataset IterableDataset

The dataset to wrap. Any iterable dataset is accepted.

required
length int

The length to report from __len__. This value is not validated against the actual number of examples; the caller is responsible for supplying a consistent value.

required
Notes

The map and shuffle methods are overridden to return a new IterableDatasetWithLength with the same reported length, so that the length is preserved through chained transformations.

filter is not overridden: the filtered dataset is returned as-is because the new length cannot be determined without iterating.

Examples:

>>> from torch.utils.data import IterableDataset
>>> ds = some_map_style_dataset.to_iterable_dataset()
>>> ds_with_len = IterableDatasetWithLength(ds, length=len(some_map_style_dataset))
>>> len(ds_with_len)
1000
Source code in src/forgather/ml/datasets/iterable_with_length.py
class IterableDatasetWithLength(IterableDataset):
    """
    A thin wrapper that adds a known length to an iterable dataset.

    PyTorch's ``IterableDataset`` does not require ``__len__``, but trainers
    and data-loader utilities often query it to calculate epoch step counts.
    When a map-style ``Dataset`` is converted to an iterable form with
    ``to_iterable_dataset()``, the length information is lost. This wrapper
    re-attaches it.

    Attribute and method accesses that are not handled by this class are
    forwarded to the wrapped dataset via ``__getattr__``, including
    ``state_dict`` / ``load_state_dict`` for checkpointing, and
    HuggingFace Dataset attributes such as ``column_names`` and ``features``.
    The wrapper's own storage attributes (``_dataset`` and ``_length``) are
    never forwarded, so a partially constructed instance (as created by
    ``pickle`` or ``copy``) raises ``AttributeError`` instead of recursing.

    Parameters
    ----------
    iterable_dataset : IterableDataset
        The dataset to wrap. Any iterable dataset is accepted.
    length : int
        The length to report from ``__len__``. This value is not validated
        against the actual number of examples; the caller is responsible for
        supplying a consistent value.

    Notes
    -----
    The `map` and `shuffle` methods are overridden to return a new
    ``IterableDatasetWithLength`` with the same reported length, so that
    the length is preserved through chained transformations.

    `filter` does *not* re-wrap its result: the filtered dataset is returned
    as-is because the new length cannot be determined without iterating.

    Examples
    --------
    >>> from torch.utils.data import IterableDataset
    >>> ds = some_map_style_dataset.to_iterable_dataset()
    >>> ds_with_len = IterableDatasetWithLength(ds, length=len(some_map_style_dataset))
    >>> len(ds_with_len)
    1000
    """

    # Names stored directly on the wrapper. They must never be looked up on
    # the wrapped dataset; see ``__getattr__``.
    _OWN_ATTRIBUTES = ("_dataset", "_length")

    def __init__(self, iterable_dataset, length: int):
        self._dataset = iterable_dataset
        self._length = length

    def __len__(self) -> int:
        """Return the length supplied at construction time."""
        return self._length

    def __iter__(self):
        """Iterate over the wrapped dataset."""
        return iter(self._dataset)

    def __repr__(self):
        return f"IterableDatasetWithLength({repr(self._dataset)}, length={repr(self._length)})"

    def map(self, *args, **kwargs):
        """
        Apply a map transformation while preserving the reported length.

        Delegates to the wrapped dataset's ``map`` method and re-wraps the
        result in a new ``IterableDatasetWithLength`` with the same length.

        Parameters
        ----------
        *args
            Positional arguments forwarded to the wrapped dataset's ``map``.
        **kwargs
            Keyword arguments forwarded to the wrapped dataset's ``map``.

        Returns
        -------
        IterableDatasetWithLength
            Mapped dataset with the same reported length as this instance.
        """
        mapped_dataset = self._dataset.map(*args, **kwargs)
        return IterableDatasetWithLength(mapped_dataset, self._length)

    def shuffle(self, *args, **kwargs):
        """
        Shuffle the dataset while preserving the reported length.

        Delegates to the wrapped dataset's ``shuffle`` method and re-wraps the
        result in a new ``IterableDatasetWithLength`` with the same length.

        Parameters
        ----------
        *args
            Positional arguments forwarded to the wrapped dataset's ``shuffle``.
        **kwargs
            Keyword arguments forwarded to the wrapped dataset's ``shuffle``.

        Returns
        -------
        IterableDatasetWithLength
            Shuffled dataset with the same reported length as this instance.
        """
        shuffled_dataset = self._dataset.shuffle(*args, **kwargs)
        return IterableDatasetWithLength(shuffled_dataset, self._length)

    def filter(self, *args, **kwargs):
        """
        Filter the dataset.

        Delegates to the wrapped dataset's ``filter`` method. The length
        information is *not* preserved because the post-filter count cannot
        be determined without iterating.

        Parameters
        ----------
        *args
            Positional arguments forwarded to the wrapped dataset's ``filter``.
        **kwargs
            Keyword arguments forwarded to the wrapped dataset's ``filter``.

        Returns
        -------
        object
            Whatever the wrapped dataset's ``filter`` returns, not re-wrapped
            in ``IterableDatasetWithLength``.
        """
        filtered_dataset = self._dataset.filter(*args, **kwargs)
        # We can't know the new length after filtering, so we lose it
        return filtered_dataset

    def __getattr__(self, name):
        """
        Forward unknown attribute and method accesses to the wrapped dataset.

        Python only calls ``__getattr__`` after normal lookup has failed, so
        a request for ``_dataset`` or ``_length`` arriving here means the
        instance has not been initialised yet (for example while ``pickle``
        or ``copy`` probes a freshly created object for ``__setstate__``).

        Parameters
        ----------
        name : str
            Name of the attribute that normal lookup could not find.

        Returns
        -------
        object
            The attribute of the same name on the wrapped dataset.

        Raises
        ------
        AttributeError
            If ``name`` is one of the wrapper's own storage attributes, or if
            the wrapped dataset does not have the attribute either.
        """
        if name in IterableDatasetWithLength._OWN_ATTRIBUTES:
            # Reading self._dataset below would re-enter this method and
            # recurse until RecursionError; fail with the conventional error.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._dataset, name)

map(*args, **kwargs)

Apply a map transformation while preserving the reported length.

Delegates to the wrapped dataset's map method and re-wraps the result in a new IterableDatasetWithLength with the same length.

Parameters:

Name Type Description Default
*args

Positional arguments forwarded to the wrapped dataset's map.

()
**kwargs

Keyword arguments forwarded to the wrapped dataset's map.

{}

Returns:

Type Description
IterableDatasetWithLength

Mapped dataset with the same reported length as this instance.

Source code in src/forgather/ml/datasets/iterable_with_length.py
def map(self, *args, **kwargs):
    """
    Map over the wrapped dataset and keep the reported length.

    The call is handed to the wrapped dataset's ``map``; whatever it returns
    is wrapped in a fresh ``IterableDatasetWithLength`` that reports the same
    length as this instance.

    Parameters
    ----------
    *args
        Positional arguments passed through to the wrapped dataset's ``map``.
    **kwargs
        Keyword arguments passed through to the wrapped dataset's ``map``.

    Returns
    -------
    IterableDatasetWithLength
        The mapped dataset, reporting the same length as this instance.
    """
    return IterableDatasetWithLength(
        self._dataset.map(*args, **kwargs),
        self._length,
    )

shuffle(*args, **kwargs)

Shuffle the dataset while preserving the reported length.

Delegates to the wrapped dataset's shuffle method and re-wraps the result in a new IterableDatasetWithLength with the same length.

Parameters:

Name Type Description Default
*args

Positional arguments forwarded to the wrapped dataset's shuffle.

()
**kwargs

Keyword arguments forwarded to the wrapped dataset's shuffle.

{}

Returns:

Type Description
IterableDatasetWithLength

Shuffled dataset with the same reported length as this instance.

Source code in src/forgather/ml/datasets/iterable_with_length.py
def shuffle(self, *args, **kwargs):
    """
    Shuffle the wrapped dataset and keep the reported length.

    The call is handed to the wrapped dataset's ``shuffle``; whatever it
    returns is wrapped in a fresh ``IterableDatasetWithLength`` that reports
    the same length as this instance.

    Parameters
    ----------
    *args
        Positional arguments passed through to the wrapped dataset's
        ``shuffle``.
    **kwargs
        Keyword arguments passed through to the wrapped dataset's ``shuffle``.

    Returns
    -------
    IterableDatasetWithLength
        The shuffled dataset, reporting the same length as this instance.
    """
    return IterableDatasetWithLength(
        self._dataset.shuffle(*args, **kwargs),
        self._length,
    )

filter(*args, **kwargs)

Filter the dataset.

Delegates to the wrapped dataset's filter method. The length information is not preserved because the post-filter count cannot be determined without iterating.

Parameters:

Name Type Description Default
*args

Positional arguments forwarded to the wrapped dataset's filter.

()
**kwargs

Keyword arguments forwarded to the wrapped dataset's filter.

{}

Returns:

Type Description
IterableDataset

Filtered dataset without a __len__ method.

Source code in src/forgather/ml/datasets/iterable_with_length.py
def filter(self, *args, **kwargs):
    """
    Filter the wrapped dataset.

    The call is handed to the wrapped dataset's ``filter`` and its result is
    returned directly. It is not re-wrapped, because the number of examples
    that survive the filter cannot be known without iterating.

    Parameters
    ----------
    *args
        Positional arguments passed through to the wrapped dataset's
        ``filter``.
    **kwargs
        Keyword arguments passed through to the wrapped dataset's ``filter``.

    Returns
    -------
    IterableDataset
        The filtered dataset, without the length wrapper.
    """
    # Deliberately no IterableDatasetWithLength here: the old length would
    # no longer be correct after filtering.
    return self._dataset.filter(*args, **kwargs)

__getattr__(name)

Forward all unknown attribute and method accesses to the wrapped dataset.

Source code in src/forgather/ml/datasets/iterable_with_length.py
def __getattr__(self, name):
    """
    Forward unknown attribute and method accesses to the wrapped dataset.

    Python only calls ``__getattr__`` after normal lookup has failed, so a
    request for ``_dataset`` or ``_length`` arriving here means the instance
    has not been initialised yet (for example while ``pickle`` or ``copy``
    probes a freshly created object for ``__setstate__``).

    Parameters
    ----------
    name : str
        Name of the attribute that normal lookup could not find.

    Returns
    -------
    object
        The attribute of the same name on the wrapped dataset.

    Raises
    ------
    AttributeError
        If ``name`` is one of the wrapper's own storage attributes, or if the
        wrapped dataset does not have the attribute either.
    """
    if name in ("_dataset", "_length"):
        # Reading self._dataset below would re-enter this method and recurse
        # until RecursionError; fail with the conventional error instead.
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )
    return getattr(self._dataset, name)