WFORunner

Source code in wt_ml/tuning/wfo_runner.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
class WFORunner:
    def __init__(
        self,
        dataset: EconomicDataset,
        model_builder: ModelBuilder,
        val_freq: Frequency | str | None = None,
        min_val_date: str | np.datetime64 | None = None,
        max_val_date: str | np.datetime64 | None = None,
        start_date: str | np.datetime64 | None = None,
        custom_val_periods: dict[str, ValidationPeriod] | Sequence[ValidationPeriod] | None = None,
        model_name_prefix: str = "Model",
    ):
        """
        WFO object that creates periods and runs walk forward optimization.
        Either pass in val_freq or custom_val_periods to create wfo periods.
        The model_builder is used to create the model used for `train_test_model`.
        This is separated from WFO, so that it's cleaner and makes it easier during parallelization.
        You can use `build_model` from `wt_ml.networks.model` and create a partial function from it assigning
        `hyperparameters` and net_combination.

        ``` python
        from functools import partial
        from wt_ml.networks.model import build_model


        model_builder = partial(
            build_model,
            hyperparameters=hyperparameters,
            net_combination=net_combination,
        )
        ```

        Args:
            dataset (EconomicDataset): The dataset object.
            model_builder (ModelBuilder): Function that creates the model and has the ModelBuilder signature.
            val_freq (Frequency | str | None): Validation period frequency. A string is looked up by name in
                `Frequency`. Defaults to None.
            min_val_date (str | np.datetime64 | None): Minimum validation week to include in periods.
                If None and `val_freq` is given, the start of the second-to-last `val_freq` bucket between
                `start_date` and `max_val_date` is used (or the only bucket when there is just one).
                If None and only `custom_val_periods` is given, the start of the first custom period is used.
                Defaults to None.
            max_val_date (str | np.datetime64 | None, optional): Maximum validation week to include in periods.
                Falls back to the last entry of `self.dates`. Defaults to None.
            start_date (str | np.datetime64 | None, optional): The start date after which we start training.
                Falls back to the date at the first `date_index` of the first dataset batch. Defaults to None.
            custom_val_periods (dict[str, ValidationPeriod] | Sequence[ValidationPeriod] | None):
                Provide custom validation periods instead of generating from `val_freq`. A sequence is converted
                to a dict keyed `period_0`, `period_1`, ... Defaults to None.
            model_name_prefix (str, optional): Prefix name given when we build the model. Defaults to "Model".

        Raises:
            ValueError: If neither `val_freq` nor `custom_val_periods` is given, if `min_val_date` cannot be
                inferred (including an empty `custom_val_periods`), if `min_val_date` is not strictly between
                `start_date` and `max_val_date`, or if a custom period ends before it starts.
        """
        # TODO (@legendof-selda): this should be done via a classmethod instead?
        if custom_val_periods is None and val_freq is None:
            raise ValueError("Provide either custom validation periods or validation frequency to run WFO periods.")

        self.model_name_prefix = model_name_prefix
        self.dataset = dataset
        self.model_builder = model_builder
        # `run` overwrites this when it is given a save_dir. Initialising it here lets `apply` and
        # `update_period_durations_df` see a defined attribute instead of raising AttributeError.
        self.save_dir: Path | None = None
        if val_freq is not None and not isinstance(val_freq, Frequency):
            val_freq = Frequency[val_freq]

        self._val_freq = val_freq
        if isinstance(custom_val_periods, Sequence):
            custom_val_periods = {f"period_{i}": val_period for i, val_period in enumerate(custom_val_periods)}

        self.custom_val_periods: dict[str, ValidationPeriod] | None = custom_val_periods
        self._max_val_date = np.datetime64(max_val_date if max_val_date else self.dates[-1])
        # date_index can start anywhere, dates will always have all the dates.
        # we also assume that date_index is always sorted in asc order.
        start_date_index = int(next(self.dataset()).date_index.numpy()[0])
        self._start_date = np.datetime64(start_date if start_date else self.dates[start_date_index])

        if min_val_date is None and self.val_freq is not None:
            date_range = pd.date_range(start=self.start_date, end=self.max_val_date, freq=self.val_freq.value)
            _min_val_date = date_range[-2] if len(date_range) > 1 else date_range[-1]
        elif min_val_date is None and self.custom_val_periods is not None:
            if not self.custom_val_periods:
                # Previously this surfaced as a bare StopIteration from next(iter(...)).
                raise ValueError("Cannot infer min_val_date")
            first_custom_period = next(iter(self.custom_val_periods.values()))
            _min_val_date = first_custom_period.val_period[0]
        elif min_val_date is not None:
            _min_val_date = min_val_date
        else:
            raise ValueError("Cannot infer min_val_date")

        self._min_val_date = np.datetime64(_min_val_date)

        if not (self.start_date < self.min_val_date and self.min_val_date < self.max_val_date):
            raise ValueError(
                f"min_val_date {self.min_val_date} is not within date range [{self.start_date}, {self.max_val_date}]."
            )

        if self.custom_val_periods is not None:
            # A period whose end precedes its start can never contain validation dates.
            invalid_periods = {
                period_name: val_period
                for period_name, val_period in self.custom_val_periods.items()
                if val_period.val_period[1] < val_period.val_period[0]
            }
            if invalid_periods:
                raise ValueError(f"Invalid Validation Periods - {invalid_periods}")

        self.wfo_period_durations = None
        self.wfo_results = None

    @cached_property
    def dates(self) -> pd.DatetimeIndex:
        """All dates known to the dataset's `date` encoding, as a `DatetimeIndex` named `dates`.

        The result is cached on the instance after the first access. The index is built from the date
        encodings and not from a batch's `date_index`, so it is not affected by where `date_index` starts.
        """
        date_lookups = get_lookups(self.dataset.encodings["date"])
        assert date_lookups is not None
        # `closed` is intentionally not passed: it has no effect when an index is built from explicit
        # data, `True` is not one of its documented values, and newer pandas releases deprecate/remove it.
        datetime_index = pd.DatetimeIndex(date_lookups, name="dates", freq="infer")
        if TYPE_CHECKING:
            assert isinstance(datetime_index, pd.DatetimeIndex)
        return datetime_index

    @property
    def val_freq(self) -> Frequency | None:
        """Frequency of WFO periods. If None, custom validation periods are used."""
        return self._val_freq

    @property
    def min_val_date(self) -> np.datetime64:
        """Earliest date included in validation periods (given explicitly or inferred in `__init__`)."""
        return self._min_val_date

    @property
    def max_val_date(self) -> np.datetime64:
        """Latest date included in validation periods (defaults to the last entry of `self.dates`)."""
        return self._max_val_date

    @property
    def start_date(self) -> np.datetime64:
        """Date from which training starts; also passed to every generated `ValidationPeriod`."""
        return self._start_date

    def get_periods(self) -> dict[str, ValidationPeriod]:
        """
        Build the validation periods that WFO iterates over.

        When `custom_val_periods` is set it is returned unchanged. Otherwise every date in
        `[min_val_date, max_val_date]` (and not before `start_date`) is labelled with a period derived from
        `val_freq`, and one `ValidationPeriod` is produced per distinct label.

        NOTE: the periods are built from date encodings and not from date_index. This can cause mismatch if the
        dataset was subsetted on time. Ensure the validation periods provided are within the date_index in dataset.

        Returns:
            dict[str, ValidationPeriod]: Key is a unique name for a period. Each value holds the
                (first, last) validation date of that period together with the start date.
        """
        custom_periods = self.custom_val_periods
        if custom_periods is not None:
            return custom_periods

        all_dates = self.dates
        dates: pd.DatetimeIndex = all_dates[all_dates >= self.start_date]
        within_val_range = (dates >= self.min_val_date) & (dates <= self.max_val_date)
        val_dates = dates[within_val_range]
        if TYPE_CHECKING:
            assert self.val_freq is not None
        val_freq = self.val_freq
        periods: pd.Index | pd.PeriodIndex

        def half_year_label(dt) -> str:
            # quarters 1-2 map to half 1, quarters 3-4 map to half 2.
            return f"{dt.year}-{val_freq.value}{(dt.quarter-1) // 2 % 2 + 1}"

        def fortnight_label(dt) -> str:
            # we dont use iso week due to weird 51,52,53 weeks. This makes it continuous.
            return f"{dt.year}-{val_freq.value}{dt.dayofyear // 7 // 2 + 1}"

        if val_freq == Frequency.half:
            periods = val_dates.map(half_year_label)
        elif val_freq == Frequency.fortnight:
            periods = val_dates.map(fortnight_label)
        else:
            periods = val_dates.to_period(val_freq.value)

        # TODO (@legendof-selda) moving start date option.
        validation_periods: dict[str, ValidationPeriod] = {}
        for period in pd.unique(periods):
            dates_in_period = val_dates[periods == period]
            validation_periods[str(period)] = ValidationPeriod(
                (dates_in_period.min(), dates_in_period.max()), self.start_date
            )
        return validation_periods

    def get_model_names(self) -> dict[str, str]:
        """Map every period name in `self.periods` to the model name used for that period.

        The model name is `<model_name_prefix>_<period>` with `-` removed and `/` replaced by `_`.
        """
        prefix = self.model_name_prefix
        names: dict[str, str] = {}
        for period in self.periods:
            raw_name = f"{prefix}_{str(period)}"
            names[period] = raw_name.replace("-", "").replace("/", "_")
        return names

    @property
    def periods(self) -> tuple[str, ...]:
        """Names of the periods returned by `get_periods`, in order."""
        period_lookup = self.get_periods()
        return tuple(period_lookup)

    @staticmethod
    def stitch_period_outputs(
        wfo_results: dict[str, TrainTestOutput],
        encodings: Encodings,
    ) -> dict[str, TrainTestOutput]:
        """Add a `"stitched"` entry that concatenates every period's validation slice along time.

        Args:
            wfo_results (dict[str, TrainTestOutput]): Per-period results. NOTE: this dict is mutated in place.
            encodings (Encodings): Encodings used to sort each period's batches.

        Returns:
            dict[str, TrainTestOutput]: The same `wfo_results` object with the extra `"stitched"` key.
        """
        validation_slices = []
        for period_result in wfo_results.values():
            # we gather on val dates as stitch metrics matter on test only and this reduces memory
            validation_only = utils.gather(
                period_result.output,
                period_result.val_dates_idx,
                axis=1,
                axis_type=Axis.Time,
            )
            # We need to sort the batches by index (based on level) as they are shuffled.
            # We could instead call inference without shuffling,
            # although it doesn't give us gaurantees and this step would make it safer.
            validation_slices.append(utils.sort_batch_index(output=validation_only, encodings=encodings))

        # now we concatenate on the Time Axis.
        stitched_output = utils.concat_intermediaries(validation_slices, axis=1, axis_type=Axis.Time)
        # The val_dates_idx is adjusted this way since after concatenation on time, all the time is part of
        # validation. here train metrics wouldn't be calculated and we only have test!
        # not all edge cases have been resolved with this change. This was done to fix the OOM issue which
        # occurs when the time axis is too huge.
        every_time_index = np.arange(_get_time_axis_length(stitched_output))
        # NOTE: this is directly mutated.
        wfo_results["stitched"] = TrainTestOutput(stitched_output, every_time_index)
        return wfo_results

    def run(
        self,
        epochs: int | None,
        *,
        validation_periods: dict[str, ValidationPeriod] | None = None,
        parallel: bool = False,
        save_dir: Path | None = DEFAULT_SAVE_DIR,
        include_stitched_period: bool = True,
        no_return: bool = False,
        callbacks_builder: Callable[[], CallbacksList] | None = None,
        options: WFOOptions = WFOOptions(),
        **kwargs,
    ) -> dict[str, TrainTestOutput] | None:
        """Run Walk forward optimization.

        Args:
            epochs (int | None): Number of epochs to train the model.
            validation_periods (dict[str, ValidationPeriod] | None, optional): WFO periods WFO will run on.
                If None, runs on `self.get_periods()`. Model names are looked up from `self.get_model_names()`,
                so every key must be one of `self.periods`. Defaults to None.
            parallel (bool, optional): Run WFO in parallel based on GPU devices. Only honoured when more than
                one logical GPU is visible; otherwise periods run sequentially. Defaults to False.
            save_dir (Path | None, optional): Directory to save the model. `None` will not save the model.
                When not None it is also stored on `self.save_dir`. Defaults to DEFAULT_SAVE_DIR.
            include_stitched_period (bool, optional): Include stitched results in output. Defaults to True.
            no_return (bool, optional): Do not return results. Special case for handling OOM. Defaults to False.
            callbacks_builder (Callable[[], CallbacksList] | None, optional): Function that returns CallbacksList.
            options (WFOOptions, optional): Remaining run settings. The fields read here are:
                `calculate_trackers` (calculate and include trackers in output),
                `retrain` (train WFO from scratch or skip training for stored models),
                `disk_mode` (period runs return paths which are then loaded via `load_period_output`),
                `checkpoint_freq`, `smoothing_window` and `model_options`
                (`include_trackers`, `delete_existing_checkpoints`, `resume_from_previous_period`,
                `checkpoint_path`, `load_initial_period`, `period_checkpoint_paths`).
            **kwargs: Forwarded unchanged to `run_period`.

        Returns:
            dict[str, TrainTestOutput] | None: Results of train_test_model for each period (plus `"stitched"`
                when `include_stitched_period` is True), or None when `no_return` is True.
        """
        # TODO (@legendof-selda) save_dir attribute needs to be refactored properly
        if save_dir is not None:
            self.save_dir = save_dir
        validation_periods = self.get_periods() if validation_periods is None else validation_periods
        gpu_devices = tf.config.list_logical_devices("GPU")
        parallel = parallel and len(gpu_devices) > 1
        results: list[dict[str, TrainTestOutput]]
        model_names = self.get_model_names()
        # we want to disable caching on parallel mode.
        model_builder = _control_model_builder_cache(self.model_builder, parallel)

        # NOTE: If previous period is set then checkpoints must be deleted
        # else the will be bugs in the plots since the epochs will get mixed up
        delete_existing_checkpoints = (
            options.model_options.delete_existing_checkpoints or options.model_options.resume_from_previous_period
        )
        run_instructions = RunInstructions(
            epochs=epochs,
            save_dir=save_dir,
            include_trackers=options.model_options.include_trackers,
            calculate_trackers=options.calculate_trackers,
            retrain=options.retrain,
            disk_mode=options.disk_mode,
            checkpoint_freq=options.checkpoint_freq,
            callbacks_builder=callbacks_builder,
            smoothing_window=options.smoothing_window,
            delete_existing_checkpoints=delete_existing_checkpoints,
        )

        _run_period = partial(
            run_period, dataset=self.dataset, model_builder=model_builder, run_instructions=run_instructions, **kwargs
        )

        if not parallel:
            model_options = options.model_options
            function_outputs = []
            first_period = True
            previous_model_name = None
            for period, validation_period in validation_periods.items():
                name = model_names[period]
                checkpoint_path = model_options.checkpoint_path
                if checkpoint_path and not (checkpoint_path / "model.index").exists():
                    # checkpoint_path points to a wfo save_dir
                    checkpoint_path = checkpoint_path / name

                if first_period:
                    previous_model_name = checkpoint_path if model_options.load_initial_period else None
                    first_period = False
                elif not model_options.resume_from_previous_period:
                    previous_model_name = model_options.period_checkpoint_paths.get(name, checkpoint_path)
                # else: keep the previous period's model name assigned at the end of the last iteration.

                function_output = _run_period(
                    period=period,
                    validation_period=validation_period,
                    name=name,
                    previous_model_name=previous_model_name,
                )
                previous_model_name = name
                function_outputs.append(function_output)
        else:
            n_devices = len(gpu_devices)
            # Each task must be `delayed(func)(**call_kwargs)` so that joblib receives a
            # (func, args, kwargs) triple and runs the period inside a worker. The callable is wrapped
            # with `use_device` first so that the worker executes it on its assigned GPU.
            # NOTE(review): assumes `utils.use_device(device)` is a decorator accepting a callable - confirm.
            function_outputs = Parallel(n_jobs=n_devices)(
                delayed(utils.use_device(gpu_devices[i % n_devices])(_run_period))(
                    period=period,
                    validation_period=validation_period,
                    name=model_names[period],
                )
                for i, (period, validation_period) in enumerate(validation_periods.items())
            )

        logger.info("WFO train complete.")
        if no_return:
            return None

        if options.disk_mode:
            function_outputs = [load_period_output(path) for path in function_outputs]

        function_outputs: list[tuple[dict[str, TrainTestOutput], dict[str, float]]]
        results, period_durations = zip(*function_outputs)
        logger.info("Combining all wfo period results.")
        wfo_results = reduce(lambda d1, d2: d1 | d2, results)
        if epochs is not None and options.retrain is False:
            # NOTE: when learning curve is built, this gets overwritten.
            # workaround to deal with it.
            self.wfo_period_durations = reduce(lambda d1, d2: d1 | d2, period_durations)
        if include_stitched_period:
            wfo_results = WFORunner.stitch_period_outputs(wfo_results, self.dataset.encodings)
        logger.info("WFO Results ready.")
        return wfo_results

    def apply(
        self,
        apply_func: WFOPeriodApplyFunc[T],
        options: WFOOptions = WFOOptions(),
    ) -> Iterator[T]:
        """Lazily apply `apply_func` on the stored model of every WFO period.

        This is a generator: nothing (including the `save_dir` check) runs until it is iterated.
        Models are loaded rather than trained (`epochs=None`, `load_existing_model=True`) and existing
        checkpoints are never deleted.

        Args:
            apply_func (WFOPeriodApplyFunc[T]): Function forwarded to `apply_on_period_model` as `apply`.
            options (WFOOptions, optional): Settings copied into the `RunInstructions` for each period.

        Yields:
            T: The result of `apply_on_period_model` for each period, in `self.get_model_names()` order.

        Raises:
            AssertionError: If no save directory is known, i.e. `run` has not been called with a `save_dir`.
        """
        # `save_dir` may never have been assigned if `run` was not called with one. Reading it through
        # getattr makes the intended assertion message fire instead of an AttributeError.
        save_dir = getattr(self, "save_dir", None)
        assert save_dir is not None, f"Run WFO first to apply {apply_func}"
        run_instructions = RunInstructions(
            epochs=None,
            save_dir=save_dir,
            include_trackers=options.model_options.include_trackers,
            calculate_trackers=options.calculate_trackers,
            retrain=options.retrain,
            disk_mode=options.disk_mode,
            checkpoint_freq=options.checkpoint_freq,
            callbacks_builder=None,
            smoothing_window=options.smoothing_window,
            delete_existing_checkpoints=False,
            load_existing_model=True,
        )
        _apply_on_period_model = partial(
            apply_on_period_model,
            apply=apply_func,
            dataset=self.dataset,
            model_builder=self.model_builder,
            run_instructions=run_instructions,
        )
        val_periods = self.get_periods()
        for period, name in self.get_model_names().items():
            yield _apply_on_period_model(
                val_period=val_periods[period],
                period=period,
                name=name,
            )

    def calculate_metrics(
        self,
        results: dict[str, TrainTestOutput],
        mask: dict[str, NDArray[np.bool_]] | None = None,
        weights: dict[str, NDArray[np.float_]] | None | Literal["auto"] = None,
        level: tuple[str, ...] = ("brand", "wholesaler"),
        calculate_custom_metrics: bool = False,
    ) -> dict[str, GroupedMetrics | dict[str, GroupedMetrics]]:
        """Compute metrics for each WFO period in `results`.

        Delegates per period to `calculate_metrics` in `wt_ml.tuning.train_test_model`.

        Args:
            results (dict[str, TrainTestOutput]): Results of train_test_model for each period.
            mask (dict[str, NDArray[np.bool_]] | None, optional): Per-period mask, looked up by period name.
                Mask tensor of the same shape as y_true and y_pred indicating which elements to mask out.
                Default is None.
            weights (dict[str, NDArray[np.float_]] | None | Literal["auto"], optional): Per-period weights used
                for taking weighted mean on the metrics. 'auto' picks `inputs.instability_loss_mult` from each
                period's output (from the last network when the output is a dict). Defaults to None.
            level (tuple[str, ...]): The level at which we want to aggregate.
                Defaults to ('brand', 'wholesaler').
            calculate_custom_metrics (bool, optional): Calculate custom metrics as well. Defaults to False.

        Returns:
            dict[str, GroupedMetrics | dict[str, GroupedMetrics]]: Metrics keyed by WFO period name.
        """
        if weights == "auto":
            auto_weights = {}
            for period, period_result in results.items():
                period_output = period_result.output
                if isinstance(period_output, dict):
                    last_network = tuple(period_output.keys())[-1]
                    auto_weights[period] = period_output[
                        last_network
                    ].inputs.instability_loss_mult  # type:ignore [reportOptionalMemberAccess]
                else:
                    auto_weights[period] = period_output.inputs.instability_loss_mult
            weights = auto_weights

        period_metrics = {}
        for period, period_result in results.items():
            period_mask = mask[period] if mask is not None else None
            period_weights = weights[period] if weights is not None else None
            period_metrics[period] = calculate_metrics(
                period_result.output,
                period_result.val_dates_idx,
                self.dataset.encodings,
                mask=period_mask,
                weights=period_weights,
                level=level,
                calculate_custom_metrics=calculate_custom_metrics,
            )
        return period_metrics

    def metrics_to_df(self, period_metrics: dict[str, GroupedMetrics | dict[str, GroupedMetrics]]) -> pd.DataFrame:
        """Turn the output of `self.calculate_metrics` into a DataFrame.

        Args:
            period_metrics (dict[str, GroupedMetrics | dict[str, GroupedMetrics]]):
                Metrics calculated from self.calculate_metrics

        Returns:
            pd.DataFrame: Metrics DataFrame.
        """
        # The first WFO period decides the layout: a dict entry means metrics are nested per network.
        first_period_metrics = period_metrics[self.periods[0]]
        if isinstance(first_period_metrics, dict):
            index_names = ["period", "network", "dataset", "metric"]
        else:
            index_names = ["period", "dataset", "metric"]
        return metrics_to_df(period_metrics, index_names=index_names)

    def calculate_curves_diff(
        self,
        results: dict[str, TrainTestOutput],
        return_df: bool = False,
        level: tuple[str, ...] = ("brand", "wholesaler"),
    ) -> dict[str, NDArray] | NDArray | pd.DataFrame:
        """Compute curve differences between each pair of consecutive WFO periods.

        Only the periods in `self.periods` are used, so a "stitched" entry in `results` is ignored.

        Args:
            results (dict[str, TrainTestOutput]): Results of train_test_model for each period.
            return_df (bool, optional): Return the results as a dataframe.
            level (tuple[str, ...]): The level at which we want to aggregate.
                NOTE: This is not implemented yet. Don't change from defaults.
                Defaults to ('brand', 'wholesaler').

        Returns:
            dict[str, NDArray] | NDArray | pd.DataFrame: curve diff calculated on each curve and agg on
                periods. Keyed by network when the period outputs are dicts; a DataFrame when `return_df`
                is True.
        """
        # TODO (@legendof-selda): create generic WFO period level metrics function like in train_test_model.
        # TODO (@legendof-selda): agg on given levels.
        # TODO (@legendof-selda): optimization to avoid sorting here but do within calculate_curves_diff.
        periods = self.periods
        encodings = self.dataset.encodings

        # Outputs have to be sorted up front because utils.calculate_curves_diff does not sort internally.
        sorted_results: dict[str, TrainTestOutput] = {}
        for period in periods:
            period_result = results[period]
            sorted_output = utils.sort_batch_index(
                output=period_result.output,
                encodings=encodings,
                level=level,
            )
            sorted_results[period] = TrainTestOutput(sorted_output, period_result.val_dates_idx)

        first_output = sorted_results[periods[0]].output
        dict_type = isinstance(first_output, dict)
        consecutive_periods = zip(periods[:-1], periods[1:])

        if dict_type:
            diffs_by_network: dict[str, list[dict[str, NDArray]]] = defaultdict(list)
            for prev, current in consecutive_periods:
                prev_output = sorted_results[prev].output
                current_output = sorted_results[current].output
                for net in current_output.keys():
                    diffs_by_network[net].append(
                        utils.calculate_curves_diff(prev_output[net], current_output[net], encodings)
                    )
            # reduce on period axis.
            curve_diffs = {
                net: utils.process_period_curves_diff(net_diffs) for net, net_diffs in diffs_by_network.items()
            }
        else:
            period_diffs: list[dict[str, NDArray]] = []
            for prev, current in consecutive_periods:
                period_diffs.append(
                    utils.calculate_curves_diff(sorted_results[prev].output, sorted_results[current].output, encodings)
                )
            # reduce on period axis.
            curve_diffs = utils.process_period_curves_diff(period_diffs)

        if not return_df:
            return curve_diffs

        if dict_type:
            inputs = first_output[next(iter(first_output.keys()))].inputs
            frame_data = flatten(curve_diffs)
            index_names = ["network", "curve_diff"]
        else:
            inputs = first_output.inputs
            frame_data = curve_diffs
            index_names = ["curve_diff"]
        group_index = utils.get_index(level=level, inputs=inputs, encodings=encodings)
        curve_diffs_df = pd.DataFrame(frame_data, index=group_index).T
        curve_diffs_df.index.names = index_names
        return curve_diffs_df

    def get_period_durations_df(self):
        """Return `self.wfo_period_durations` as a DataFrame with one `training_time` column."""
        durations_row = pd.DataFrame(self.wfo_period_durations, index=["training_time"])
        return durations_row.T

    def update_period_durations_df(self):
        """Reload `self.wfo_period_durations` from each model's `duration.json` under `self.save_dir`.

        For every period the file `<save_dir>/<model_name>/duration.json` is parsed and the entry keyed by
        the period name is kept.
        """
        # TODO (@legendof-selda) this is a workaround to get the periods duration back.
        reloaded_durations = {}
        for period, model_name in self.get_model_names().items():
            duration_file = (self.save_dir / model_name) / "duration.json"
            with duration_file.open("rb") as fp:
                stored_durations = orjson.loads(fp.read())
            reloaded_durations[period] = stored_durations[period]

        self.wfo_period_durations = reloaded_durations

val_freq: Frequency | None property

Frequency of WFO periods. If None, custom validation periods are used.

__init__(dataset, model_builder, val_freq=None, min_val_date=None, max_val_date=None, start_date=None, custom_val_periods=None, model_name_prefix='Model')

WFO object that creates periods and runs walk forward optimization. Either pass in val_freq or custom_val_periods to create wfo periods. The model_builder is used to create the model used for train_test_model. This is separated from WFO, so that it's cleaner and makes it easier during parallelization. You can use build_model from wt_ml.networks.model and create a partial function from it assigning hyperparameters and net_combination.

from functools import partial
from wt_ml.networks.model import build_model


model_builder = partial(
    build_model,
    hyperparameters=hyperparameters,
    net_combination=net_combination,
)

Parameters:

Name Type Description Default
dataset EconomicDataset

The dataset object.

required
model_builder ModelBuilder

Function that creates the model and has the ModelBuilder signature.

required
val_freq Frequency | str | None

Validation period frequency. Defaults to None.

None
min_val_date str | datetime64 | None

Minimum validation week to include in periods. If None, ensures that there is at least 1 period to train on. Defaults to None.

None
max_val_date str | datetime64 | None

Maximum validation week to include in periods. Defaults to None.

None
start_date str | datetime64 | None

The start date after which we start training. Defaults to None.

None
custom_val_periods dict[str, ValidationPeriod] | Sequence[ValidationPeriod] | None

Provide custom validation periods instead of generating from val_freq. Defaults to None.

None
model_name_prefix str

Prefix name given when we build the model. Defaults to "Model".

'Model'
Source code in wt_ml/tuning/wfo_runner.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
def __init__(
    self,
    dataset: EconomicDataset,
    model_builder: ModelBuilder,
    val_freq: Frequency | str | None = None,
    min_val_date: str | np.datetime64 | None = None,
    max_val_date: str | np.datetime64 | None = None,
    start_date: str | np.datetime64 | None = None,
    custom_val_periods: dict[str, ValidationPeriod] | Sequence[ValidationPeriod] | None = None,
    model_name_prefix: str = "Model",
):
    """
    WFO object that creates periods and runs walk forward optimization.
    Either pass in val_freq or custom_val_periods to create wfo periods.
    The model_builder is used to create the model used for `train_test_model`.
    This is separated from WFO, so that it's cleaner and makes it easier during parallelization.
    You can use `build_model` from `wt_ml.networks.model` and create a partial function from it assigning
    `hyperparameters` and `net_combination`.

    ``` python
    from functools import partial
    from wt_ml.networks.model import build_model


    model_builder = partial(
        build_model,
        hyperparameters=hyperparameters,
        net_combination=net_combination,
    )
    ```

    Args:
        dataset (EconomicDataset): The dataset object.
        model_builder (ModelBuilder): Function that creates the model and has the ModelBuilder signature.
        val_freq (Frequency | str | None): Validation period frequency. A string is looked up by name in
            `Frequency`. Defaults to None.
        min_val_date (str | np.datetime64 | None): Minimum validation week to include in periods.
            If None and `val_freq` is set, it is taken from the `val_freq` date range between the start date
            and the max validation date, so that there is at least 1 period to train on.
            If None and only `custom_val_periods` is set, the start of the first custom period is used.
            Defaults to None.
        max_val_date (str | np.datetime64 | None, optional): Maximum validation week to include in periods.
            Falls back to the last entry of `self.dates`. Defaults to None.
        start_date (str | np.datetime64 | None, optional): The start date after which we start training.
            Falls back to the date at the first `date_index` of the dataset. Defaults to None.
        custom_val_periods (dict[str, ValidationPeriod] | Sequence[ValidationPeriod] | None):
            Provide custom validation periods instead of generating from `val_freq`.
            A sequence is converted to a dict keyed `period_0`, `period_1`, ... Defaults to None.
        model_name_prefix (str, optional): Prefix name given when we build the model. Defaults to "Model".

    Raises:
        ValueError: If neither `val_freq` nor `custom_val_periods` is provided, if `min_val_date` cannot be
            inferred (including an empty `custom_val_periods`), if `min_val_date` is not strictly between the
            start date and the max validation date, or if a custom validation period ends before it starts.
    """
    # TODO (@legendof-selda): this should be done via a classmethod instead?
    if custom_val_periods is None and val_freq is None:
        raise ValueError("Provide either custom validation periods or validation frequency to run WFO periods.")

    self.model_name_prefix = model_name_prefix
    self.dataset = dataset
    self.model_builder = model_builder
    if val_freq is not None and not isinstance(val_freq, Frequency):
        val_freq = Frequency[val_freq]

    self._val_freq = val_freq
    if isinstance(custom_val_periods, Sequence):
        custom_val_periods = {f"period_{i}": val_period for i, val_period in enumerate(custom_val_periods)}

    self.custom_val_periods: dict[str, ValidationPeriod] | None = custom_val_periods
    self._max_val_date = np.datetime64(max_val_date if max_val_date else self.dates[-1])
    # date_index can start anywhere, dates will always have all the dates.
    # we also assume that date_index is always sorted in asc order.
    start_date_index = int(next(self.dataset()).date_index.numpy()[0])
    self._start_date = np.datetime64(start_date if start_date else self.dates[start_date_index])

    if min_val_date is None and self.val_freq is not None:
        date_range = pd.date_range(start=self.start_date, end=self.max_val_date, freq=self.val_freq.value)
        # second to last boundary, so that at least one full period is left for training.
        _min_val_date = date_range[-2] if len(date_range) > 1 else date_range[-1]
    elif min_val_date is None and self.custom_val_periods is not None:
        if not self.custom_val_periods:
            # an empty mapping would otherwise leak a bare StopIteration out of next().
            raise ValueError("Cannot infer min_val_date from empty custom_val_periods.")
        first_period = next(iter(self.custom_val_periods.values()))
        _min_val_date = first_period.val_period[0]
    elif min_val_date is not None:
        _min_val_date = min_val_date
    else:
        raise ValueError("Cannot infer min_val_date")

    self._min_val_date = np.datetime64(_min_val_date)

    if not (self.start_date < self.min_val_date < self.max_val_date):
        raise ValueError(
            f"min_val_date {self.min_val_date} is not within date range [{self.start_date}, {self.max_val_date}]."
        )

    if self.custom_val_periods is not None:
        # a validation period is invalid when it ends before it starts.
        invalid_periods = {
            period_name: val_period
            for period_name, val_period in self.custom_val_periods.items()
            if val_period.val_period[1] < val_period.val_period[0]
        }
        if invalid_periods:
            raise ValueError(f"Invalid Validation Periods - {invalid_periods}")

    self.wfo_period_durations = None
    self.wfo_results = None

calculate_curves_diff(results, return_df=False, level=('brand', 'wholesaler'))

Calculates curve differences between consecutive periods.

Parameters:

Name Type Description Default
results dict[str, TrainTestOutput]

Results of train_test_model for each period.

required
return_df bool

Return the results as a dataframe.

False
level tuple[str, ...] | str | None

The level at which we want to aggregate. NOTE: This is not implemented yet. Don't change from defaults. 'country' will aggregate it to all. None will use ('brand', 'wholesaler'). Defaults to ('brand', 'wholesaler').

('brand', 'wholesaler')

Returns:

Type Description
dict[str, NDArray] | NDArray | DataFrame

dict[str, NDArray] | NDArray: curve diff calculated on each curve and agg on periods.

Source code in wt_ml/tuning/wfo_runner.py
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
def calculate_curves_diff(
    self,
    results: dict[str, TrainTestOutput],
    return_df: bool = False,
    level: tuple[str, ...] = ("brand", "wholesaler"),
) -> dict[str, NDArray] | NDArray | pd.DataFrame:
    """Compute curve differences between each pair of consecutive WFO periods.

    Args:
        results (dict[str, TrainTestOutput]): Results of train_test_model for each period.
            Must contain an entry for every name in `self.periods`.
        return_df (bool, optional): Return the results as a dataframe. Defaults to False.
        level (tuple[str, ...] | str | None): The level at which we want to aggregate.
            NOTE: This is not implemented yet. Don't change from defaults.
            'country' will aggregate it to all. `None` will use ('brand', 'wholesaler').
            Defaults to ('brand', 'wholesaler').

    Returns:
        dict[str, NDArray] | NDArray | pd.DataFrame: curve diff calculated on each curve and agg on periods.
            When the period outputs are dicts the result is keyed by network.
            With `return_df=True` the same values are returned as a dataframe.
    """
    # TODO (@legendof-selda): create generic WFO period level metrics function like in train_test_model.
    # TODO (@legendof-selda): agg on given levels.
    # TODO (@legendof-selda): optimization to avoid sorting here but do within calculate_curves_diff.
    # Only the names in self.periods are used, so the stitched period is left out.
    periods = self.periods
    encodings = self.dataset.encodings

    # utils.calculate_curves_diff does not sort internally, so every period output is sorted first.
    sorted_results: dict[str, TrainTestOutput] = {}
    for period in periods:
        period_result = results[period]
        sorted_output = utils.sort_batch_index(output=period_result.output, encodings=encodings, level=level)
        sorted_results[period] = TrainTestOutput(sorted_output, period_result.val_dates_idx)

    first_output = sorted_results[periods[0]].output
    dict_type = isinstance(first_output, dict)
    differences: dict[str, list[dict[str, NDArray]]] | list[dict[str, NDArray]]
    differences = defaultdict(list) if dict_type else []

    # walk over (previous, current) pairs of consecutive periods.
    for prev, current in zip(periods, periods[1:]):
        prev_output = sorted_results[prev].output
        current_output = sorted_results[current].output
        if not dict_type:
            differences.append(utils.calculate_curves_diff(prev_output, current_output, encodings))
            continue
        for net, current_net_output in current_output.items():
            differences[net].append(utils.calculate_curves_diff(prev_output[net], current_net_output, encodings))

    # reduce on period axis.
    if dict_type:
        curve_diffs = {net: utils.process_period_curves_diff(net_diffs) for net, net_diffs in differences.items()}
    else:
        curve_diffs = utils.process_period_curves_diff(differences)

    if not return_df:
        return curve_diffs

    # the inputs of the first period (first network for dict outputs) provide the index.
    inputs = next(iter(first_output.values())).inputs if dict_type else first_output.inputs
    group_index = utils.get_index(level=level, inputs=inputs, encodings=encodings)
    if dict_type:
        frame_data = flatten(curve_diffs)
        index_names = ["network", "curve_diff"]
    else:
        frame_data = curve_diffs
        index_names = ["curve_diff"]
    curve_diffs_df = pd.DataFrame(frame_data, index=group_index).T
    curve_diffs_df.index.names = index_names
    return curve_diffs_df

calculate_metrics(results, mask=None, weights=None, level=('brand', 'wholesaler'), calculate_custom_metrics=False)

Calculate metrics for WFO period results. Internally calls calculate_metrics in wt_ml.tuning.train_test_model

Parameters:

Name Type Description Default
results dict[str, TrainTestOutput]

Results of train_test_model for each period.

required
mask dict[str, NDArray[bool_]] | None

Mask tensor of the same shape as y_true and y_pred indicating which elements to mask out for each period. Default is None.

None
weights dict[str, NDArray[float_]] | None | Literal['auto']

Weights used for taking weighted mean on the metrics for each period. 'auto' will pick the instability_loss_mult from results. Defaults to None.

None
level tuple[str, ...] | str | None

The level at which we want to aggregate. 'country' will aggregate it to all. None will use ('brand', 'wholesaler'). Defaults to ('brand', 'wholesaler').

('brand', 'wholesaler')
calculate_custom_metrics bool

Calculate custom metrics as well. Defaults to False.

False

Returns:

Type Description
dict[str, GroupedMetrics | dict[str, GroupedMetrics]]

dict[str, GroupedMetrics]: Metrics grouped as train and test for each WFO period.

Source code in wt_ml/tuning/wfo_runner.py
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
def calculate_metrics(
    self,
    results: dict[str, TrainTestOutput],
    mask: dict[str, NDArray[np.bool_]] | None = None,
    weights: dict[str, NDArray[np.float_]] | None | Literal["auto"] = None,
    level: tuple[str, ...] = ("brand", "wholesaler"),
    calculate_custom_metrics: bool = False,
) -> dict[str, GroupedMetrics | dict[str, GroupedMetrics]]:
    """Compute the metrics of every WFO period in `results`.
    Delegates each period to `calculate_metrics` in `wt_ml.tuning.train_test_model`.

    Args:
        results (dict[str, TrainTestOutput]): Results of train_test_model for each period.
        mask (dict[str, NDArray[np.bool_]] | None, optional): Mask tensor of the same shape as y_true and y_pred
            indicating which elements to mask out for each period. Default is None.
        weights (dict[str, NDArray[np.float_]] | None | Literal["auto"], optional): Weights used for taking weighted
            mean on the metrics for each period. 'auto' will pick the `instability_loss_mult` from results
            (from the last network when a period output is a dict). Defaults to None.
        level (tuple[str, ...] | str | None): The level at which we want to aggregate.
            'country' will aggregate it to all. `None` will use ('brand', 'wholesaler').
            Defaults to ('brand', 'wholesaler').
        calculate_custom_metrics (bool, optional): Calculate custom metrics as well. Defaults to False.

    Returns:
        dict[str, GroupedMetrics]: `Metrics` grouped as train and test for each WFO period.
    """
    if weights == "auto":
        auto_weights = {}
        for period, period_result in results.items():
            model_output = period_result.output
            if isinstance(model_output, dict):
                # one output per network: read the multiplier from the last one.
                last_net = tuple(model_output.keys())[-1]
                model_output = model_output[last_net]
            auto_weights[period] = (
                model_output.inputs.instability_loss_mult  # type:ignore [reportOptionalMemberAccess]
            )
        weights = auto_weights

    period_metrics = {}
    for period, period_result in results.items():
        period_metrics[period] = calculate_metrics(
            period_result.output,
            period_result.val_dates_idx,
            self.dataset.encodings,
            mask=mask[period] if mask is not None else None,
            weights=weights[period] if weights is not None else None,
            level=level,
            calculate_custom_metrics=calculate_custom_metrics,
        )
    return period_metrics

get_periods()

Creates the periods WFO will run on. For each Frequency type, periods are generated. NOTE: the periods are built from date encodings and not from date_index. This can cause mismatch if the dataset was subsetted on time. Ensure the validation periods provided are within the date_index in dataset. When custom_val_periods is set then we return custom_val_periods!

Returns:

Type Description
dict[str, ValidationPeriod]

dict[str, ValidationPeriod]: Key is a unique name for a period. Contains validation range and start date.

Source code in wt_ml/tuning/wfo_runner.py
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
def get_periods(self) -> dict[str, ValidationPeriod]:
    """
    Build the validation periods WFO will run on.
    The validation dates are grouped into periods according to `self.val_freq`.
    NOTE: the periods are built from date encodings and not from date_index. This can cause mismatch if the dataset
    was subsetted on time. Ensure the validation periods provided are within the date_index in dataset.
    When custom_val_periods is set, it is returned as is and nothing is generated.

    Returns:
        dict[str, ValidationPeriod]: Key is a unique name for a period. Contains validation range and start date.
    """
    custom_periods = self.custom_val_periods
    if custom_periods is not None:
        return custom_periods

    dates: pd.DatetimeIndex = self.dates[self.dates >= self.start_date]
    in_validation_range = (dates >= self.min_val_date) & (dates <= self.max_val_date)
    val_dates = dates[in_validation_range]
    if TYPE_CHECKING:
        assert self.val_freq is not None
    val_freq = self.val_freq
    periods: pd.Index | pd.PeriodIndex

    def half_label(dt):
        # quarters 1-2 map to the first half of the year, quarters 3-4 to the second.
        return f"{dt.year}-{val_freq.value}{(dt.quarter-1) // 2 % 2 + 1}"

    def fortnight_label(dt):
        # we dont use iso week due to weird 51,52,53 weeks. This makes it continuous.
        return f"{dt.year}-{val_freq.value}{dt.dayofyear // 7 // 2 + 1}"

    if val_freq == Frequency.half:
        periods = val_dates.map(half_label)
    elif val_freq == Frequency.fortnight:
        periods = val_dates.map(fortnight_label)
    else:
        # every other frequency is handled by pandas directly.
        periods = val_dates.to_period(val_freq.value)

    # TODO (@legendof-selda) moving start date option.
    validation_periods = {}
    for period in pd.unique(periods):
        period_dates = val_dates[periods == period]
        validation_periods[str(period)] = ValidationPeriod(
            (period_dates.min(), period_dates.max()), self.start_date
        )
    return validation_periods

metrics_to_df(period_metrics)

Convert period_metrics to a dataframe.

Parameters:

Name Type Description Default
period_metrics dict[str, GroupedMetrics | dict[str, GroupedMetrics]]

Metrics calculated from self.calculate_metrics

required

Returns:

Type Description
DataFrame

pd.DataFrame: Metrics DataFrame.

Source code in wt_ml/tuning/wfo_runner.py
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
def metrics_to_df(self, period_metrics: dict[str, GroupedMetrics | dict[str, GroupedMetrics]]) -> pd.DataFrame:
    """Turn `period_metrics` into a dataframe.

    Args:
        period_metrics (dict[str, GroupedMetrics | dict[str, GroupedMetrics]]):
            Metrics calculated from self.calculate_metrics

    Returns:
        pd.DataFrame: Metrics DataFrame.
    """
    index_names = ["period", "dataset", "metric"]
    # when the first period holds a dict, the metrics carry an extra network level.
    first_period_metrics = period_metrics[self.periods[0]]
    if isinstance(first_period_metrics, dict):
        index_names.insert(1, "network")
    return metrics_to_df(period_metrics, index_names=index_names)

run(epochs, *, validation_periods=None, parallel=False, save_dir=DEFAULT_SAVE_DIR, include_stitched_period=True, no_return=False, callbacks_builder=None, options=WFOOptions(), **kwargs)

Run Walk forward optimization.

Parameters:

Name Type Description Default
epochs int

Number of epochs to train the model.

required
validation_periods dict[str, ValidationPeriod] | None

WFO periods WFO will run on. If None, runs on self.get_periods(). Defaults to None.

None
parallel bool

Run WFO in parallel based on GPU devices. Defaults to False.

False
save_dir Path | None

Directory to save the model. None will not save the model. Defaults to DEFAULT_SAVE_DIR.

DEFAULT_SAVE_DIR
calculate_trackers bool

Calculate and include trackers in output. Defaults to False.

required
include_stitched_period bool

Include stitched results in output. Defaults to True.

True
retrain bool

Train WFO from scratch or skip training for stored models.

required
disk_mode bool

Do not return outputs. Defaults to False. Should trigger load_period_outputs after run_period is completed.

required
no_return bool

Do not return results. Special case for handling OOM. Defaults to False.

False
callbacks_builder Callable[[], CallbacksList] | None

Function that returns CallbacksList

None
smoothing_window bool

Smooth tail weeks data by appending additional weeks. Defaults to False.

required
partial_checkpoint_enabled bool | Literal['resume'] | Path

For non parallel runs, load previous period model weights. If "resume" load existing initial period. If Path is provided load the given Path for first period only. Defaults to True.

required

Returns:

Type Description
dict[str, TrainTestOutput] | None

dict[str, TrainTestOutput]: Results of train_test_model for each period.

Source code in wt_ml/tuning/wfo_runner.py
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
def run(
    self,
    epochs: int | None,
    *,
    validation_periods: dict[str, ValidationPeriod] | None = None,
    parallel: bool = False,
    save_dir: Path | None = DEFAULT_SAVE_DIR,
    include_stitched_period: bool = True,
    no_return: bool = False,
    callbacks_builder: Callable[[], CallbacksList] | None = None,
    options: WFOOptions = WFOOptions(),
    **kwargs,
) -> dict[str, TrainTestOutput] | None:
    """Run Walk forward optimization.

    Every validation period is passed to `run_period`, either one after the other or in parallel over the
    GPU devices, and the per period results are merged into a single dict.

    Args:
        epochs (int | None): Number of epochs to train the model.
        validation_periods (dict[str, ValidationPeriod] | None, optional): WFO periods WFO will run on.
            If None, runs on `self.get_periods()`. Defaults to None.
        parallel (bool, optional): Run WFO in parallel based on GPU devices.
            Only takes effect when more than one logical GPU device is available. Defaults to False.
        save_dir (Path | None, optional): Directory to save the model. `None` will not save the model.
            When not None it is also stored on `self.save_dir`. Defaults to DEFAULT_SAVE_DIR.
        include_stitched_period (bool, optional): Include stitched results in output. Defaults to True.
        no_return (bool, optional): Do not return results. Special case for handling OOM. Defaults to False.
        callbacks_builder (Callable[[], CallbacksList] | None, optional): Function that returns CallbacksList
        options (WFOOptions, optional): Options of the run. The fields read here are:

            - `calculate_trackers`: Calculate and include trackers in output.
            - `retrain`: Train WFO from scratch or skip training for stored models.
            - `disk_mode`: Do not return outputs from the periods. The period outputs are then loaded with
              `load_period_output` once all periods are completed.
            - `smoothing_window`: Smooth tail weeks data by appending additional weeks.
            - `checkpoint_freq`: forwarded to the run instructions.
            - `model_options`: `include_trackers`, `delete_existing_checkpoints`,
              `resume_from_previous_period`, `checkpoint_path`, `load_initial_period` and
              `period_checkpoint_paths`. For non parallel runs these select which model is passed as
              `previous_model_name`: the first period uses `checkpoint_path` when `load_initial_period` is
              set; later periods use the previous period model when `resume_from_previous_period` is set,
              otherwise `period_checkpoint_paths.get(name, checkpoint_path)`.
        **kwargs: Forwarded to `run_period`.

    Returns:
        dict[str, TrainTestOutput] | None: Results of train_test_model for each period.
            `None` when `no_return` is set.
    """
    # TODO (@legendof-selda) save_dir attribute needs to be refactored properly
    if save_dir is not None:
        self.save_dir = save_dir
    validation_periods = self.get_periods() if validation_periods is None else validation_periods
    gpu_devices = tf.config.list_logical_devices("GPU")
    # parallel mode needs more than one GPU device, otherwise we fall back to the sequential loop.
    parallel = parallel and len(gpu_devices) > 1
    results: list[dict[str, TrainTestOutput]]
    model_names = self.get_model_names()
    # we want to disable caching on parallel mode.
    model_builder = _control_model_builder_cache(self.model_builder, parallel)

    # NOTE: If previous period is set then checkpoints must be deleted
    # else there will be bugs in the plots since the epochs will get mixed up
    delete_existing_checkpoints = (
        options.model_options.delete_existing_checkpoints or options.model_options.resume_from_previous_period
    )
    run_instructions = RunInstructions(
        epochs=epochs,
        save_dir=save_dir,
        include_trackers=options.model_options.include_trackers,
        calculate_trackers=options.calculate_trackers,
        retrain=options.retrain,
        disk_mode=options.disk_mode,
        checkpoint_freq=options.checkpoint_freq,
        callbacks_builder=callbacks_builder,
        smoothing_window=options.smoothing_window,
        delete_existing_checkpoints=delete_existing_checkpoints,
    )

    # everything that is shared between periods is bound once; only the period specifics vary per call.
    _run_period = partial(
        run_period, dataset=self.dataset, model_builder=model_builder, run_instructions=run_instructions, **kwargs
    )

    if not parallel:
        model_options = options.model_options
        function_outputs = []
        first_period = True
        previous_model_name = None
        for period, validation_period in validation_periods.items():
            name = model_names[period]
            checkpoint_path = model_options.checkpoint_path
            if checkpoint_path and not (checkpoint_path / "model.index").exists():
                # checkpoint_path points to a wfo save_dir
                checkpoint_path = checkpoint_path / name

            if first_period:
                previous_model_name = checkpoint_path if model_options.load_initial_period else None
                first_period = False
            elif not model_options.resume_from_previous_period:
                previous_model_name = model_options.period_checkpoint_paths.get(name, checkpoint_path)
            # otherwise previous_model_name still holds the name of the previous period model.

            function_output = _run_period(
                period=period,
                validation_period=validation_period,
                name=name,
                previous_model_name=previous_model_name,
            )
            previous_model_name = name
            function_outputs.append(function_output)
    else:
        # periods are assigned to the GPU devices round-robin; no previous_model_name is passed here.
        # NOTE(review): `_run_period(...)` is called while the generator is consumed and its result is what
        # gets wrapped by `use_device`/`delayed` - confirm this is intended and not meant to be
        # `delayed(use_device(device)(_run_period))(...)`.
        function_outputs = Parallel(n_jobs=len(gpu_devices))(
            delayed(
                utils.use_device(gpu_devices[i % len(gpu_devices)])(
                    _run_period(
                        period=period,
                        validation_period=validation_period,
                        name=model_names[period],
                    )
                )
            )
            for i, (period, validation_period) in enumerate(validation_periods.items())
        )

    logger.info("WFO train complete.")
    if no_return:
        return

    if options.disk_mode:
        # in disk mode the period outputs are loaded back through load_period_output.
        function_outputs = [load_period_output(path) for path in function_outputs]

    function_outputs: list[tuple[dict[str, TrainTestOutput], dict[str, float]]]
    results, period_durations = zip(*function_outputs)
    logger.info("Combining all wfo period results.")
    # each period returns its own dict; merge them into one.
    wfo_results = reduce(lambda d1, d2: d1 | d2, results)
    if epochs is not None and options.retrain is False:
        # NOTE: when learning curve is built, this gets overwritten.
        # workaround to deal with it.
        self.wfo_period_durations = reduce(lambda d1, d2: d1 | d2, period_durations)
    if include_stitched_period:
        wfo_results = WFORunner.stitch_period_outputs(wfo_results, self.dataset.encodings)
    logger.info("WFO Results ready.")
    return wfo_results

calculate_metrics(output, val_dates_idx, encodings, mask=None, weights=None, level=('brand', 'wholesaler'), calculate_custom_metrics=False)

Calculate metrics for given output. If output is a dict, we recursively calculate metrics for each key in the dict.

Parameters:

Name Type Description Default
output ModelOutputType

Outputs of the model.

required
val_dates_idx NDArray[int64]

The dates_index indices which is in validation period.

required
encodings dict[str, dict[str, int]]

Encodings dict which will be used to decode the index values.

required
mask NDArray[bool_] | None

Mask tensor of the same shape as y_true and y_pred indicating which elements to mask out. Default is None.

None
weights NDArray[float_] | None

Weights used for taking weighted mean on the metrics. Defaults to None.

None
level tuple[str, ...] | str | None

The level at which we want to aggregate it or the metrics correspond to. 'country' will aggregate it to all. None will use ('brand', 'wholesaler'). Defaults to ('brand', 'wholesaler').

('brand', 'wholesaler')
calculate_custom_metrics bool

Calculate custom metrics as well. Defaults to False.

False

Returns:

Name Type Description
GroupedMetrics GroupedMetrics | dict[str, GroupedMetrics]

Metrics grouped as train and test.

Source code in wt_ml/tuning/train_test_model.py
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
def calculate_metrics(
    output: ModelOutputType,
    val_dates_idx: NDArray[np.int64],
    encodings: Encodings,
    mask: NDArray[np.bool_] | None = None,
    weights: NDArray[np.float_] | None = None,
    level: tuple[str, ...] | str | None = ("brand", "wholesaler"),
    calculate_custom_metrics: bool = False,
) -> GroupedMetrics | dict[str, GroupedMetrics]:
    """Compute train and test metrics for a model output.
    A dict output is handled recursively, producing one entry per key.

    Args:
        output (ModelOutputType): Outputs of the model.
        val_dates_idx (NDArray[np.int64]): The dates_index indices which is in validation period.
        encodings (dict[str, dict[str, int]]): Encodings dict which will be used to decode the index values.
        mask (NDArray[np.bool_] | None, optional): Mask tensor of the same shape as y_true and y_pred
            indicating which elements to mask out. When None, `output.mask` is used. Default is None.
        weights (NDArray[np.float_] | None, optional): Weights used for taking weighted mean on the metrics.
            Defaults to None.
        level (tuple[str, ...] | str | None): The level at which we want to aggregate it or the metrics correspond to.
            'country' will aggregate it to all. `None` will use ('brand', 'wholesaler').
            Defaults to ('brand', 'wholesaler').
        calculate_custom_metrics (bool, optional): Calculate custom metrics as well. Defaults to False.

    Returns:
        GroupedMetrics: `Metrics` grouped as train and test, together with the index they are aggregated on.
    """
    if isinstance(output, dict):
        nested_metrics = {}
        for key in output.keys():
            nested_metrics[key] = calculate_metrics(
                output[key],
                val_dates_idx,
                encodings,
                mask=mask,
                weights=weights,
                level=level,
                calculate_custom_metrics=calculate_custom_metrics,
            )
        return nested_metrics

    if TYPE_CHECKING:
        assert isinstance(output, EconomicIntermediaries)
        assert output.inputs is not None

    base_mask = np.array(output.mask) if mask is None else mask

    # batch, time
    # the train mask drops the validation dates, the test mask drops every other date.
    train_mask = np.array(base_mask)
    test_mask = np.array(base_mask)
    all_dates_idx = np.arange(base_mask.shape[1])
    train_dates_idx = np.setdiff1d(all_dates_idx, val_dates_idx, assume_unique=True)
    train_mask[:, val_dates_idx] = 0.0
    test_mask[:, train_dates_idx] = 0.0

    group_index = utils.get_index(level=level, inputs=output.inputs, encodings=encodings)

    if weights is not None:
        # weights will aggregate the metrics to a scalar.
        agg_index = pd.Index(["aggregated"])
    else:
        # all the metrics are aggregated in this index
        # NOTE: currently we do not sort the values based on index!
        # if we are sorting on index then this should be changed.
        index_names = group_index.names
        if len(index_names) == 1:
            agg_index = pd.Index(pd.unique(group_index), name=index_names[0])
        else:
            agg_index = pd.MultiIndex.from_tuples(pd.unique(group_index), names=index_names)

    metrics_kwargs = {
        "norm_factor": encodings.get("normalization_factor", 1.0),
        "weights": weights,
        "calculate_custom_metrics": calculate_custom_metrics,
    }
    train_metrics = _get_metrics(group_index, output, train_mask, **metrics_kwargs)
    test_metrics = _get_metrics(group_index, output, test_mask, **metrics_kwargs)
    return GroupedMetrics(train=train_metrics, test=test_metrics, index=agg_index)

concat_intermediaries(intermediaries, axis=0, axis_type=Axis.Batch)

Concatenates tf.experimental.ExtensionType objects in the batch_axis. Make sure that both axis and axis_type matches else it leads to unexpected results.

Parameters:

Name Type Description Default
intermediaries list[ExtensionType]

tf.experimental.ExtensionType type instance.

required
axis int

The axis where we need to concatenate. Defaults to 0 which is assumed to be batch.

0
axis_type Axis

The annotated axis where we need to concatenate. Defaults to Axis.Batch.

Axis.Batch

Returns:

Type Description
T

tf.experimental.ExtensionType: The concatenated output which is the same type that is passed in.

Source code in wt_ml/tuning/utils.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
@in_cpu
@warn_once(ConcatWarning)
def concat_intermediaries(
    intermediaries: list[T],
    axis: int = 0,
    axis_type: Axis = Axis.Batch,
) -> T:
    """Merge a list of per-batch containers into a single container of the same type.

    The fields of the first element drive the merge: for every field, the matching
    values of all elements are collected and combined. Nested containers are merged
    recursively; everything else is delegated to `_concatenate_intermediaries_leaf`.
    `axis` and `axis_type` are forwarded unchanged and must describe the same axis.

    Args:
        intermediaries (list[tf.experimental.ExtensionType]): Containers to merge. Mappings
            are read via `.keys()`/`.get()`, any other object via `vars()`/`getattr()`.
        axis (int, optional): Positional axis to concatenate on. Defaults to 0.
        axis_type (Axis, optional): Annotated axis to concatenate on. Defaults to Axis.Batch.

    Returns:
        tf.experimental.ExtensionType: A new instance of the first element's type built from
            the merged fields. If the input has no `__len__` or is empty it is returned
            as-is; a single-element input returns that element.
    """
    # Not a sized container, or nothing inside it: hand the input back instead of failing.
    if not hasattr(intermediaries, "__len__") or len(intermediaries) < 1:
        return intermediaries
    # Only one batch (e.g. full-batch evaluation): there is nothing to concatenate.
    if len(intermediaries) == 1:
        return intermediaries[0]

    first = intermediaries[0]
    container_cls = type(first)
    # Prefer a class-level `annotations` attribute, otherwise use the standard `__annotations__`.
    fallback_annotations = getattr(container_cls, "__annotations__", {})
    field_annotations = getattr(container_cls, "annotations", fallback_annotations)
    # Mapping-like containers expose their fields as keys; all others are read as attributes.
    is_mapping = issubclass(container_cls, (dict, Mapping))
    field_names = first.keys() if is_mapping else vars(first)

    merged = {}
    for name in field_names:
        annotation = field_annotations.get(name, None)
        if is_mapping:
            per_batch = [item.get(name) for item in intermediaries]
        else:
            per_batch = [getattr(item, name) for item in intermediaries]

        # The first batch's value alone decides how the field is treated.
        head = per_batch[0]
        if head is None:
            merged[name] = None
            continue
        if isinstance(head, (tf.experimental.ExtensionType, dict, Mapping, SimpleNamespace)):
            # Nested container: merge it with the same axis settings.
            merged[name] = concat_intermediaries(per_batch, axis=axis, axis_type=axis_type)
            continue
        merged[name] = _concatenate_intermediaries_leaf(annotation, per_batch, head, axis, axis_type, name)

    return container_cls(**merged)

metrics_to_df(metrics_data, index_names=['dataset', 'metric'])

Convert given GroupedMetrics to a dataframe.

Parameters:

Name Type Description Default
metrics_data GroupedMetrics | dict[str, GroupedMetrics]

Calculated GroupedMetrics.

required
index_names list

The index names for the dataframe. Defaults to ["dataset", "metric"]. The GroupedMetrics are flattened, so the index is usually (*any_parent_levels, 'dataset', 'metric').

['dataset', 'metric']

Returns:

Type Description
DataFrame

pd.DataFrame: Metrics pandas dataframe.

Source code in wt_ml/tuning/train_test_model.py
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
def metrics_to_df(
    metrics_data: GroupedMetrics | dict[str, GroupedMetrics],
    index_names=["dataset", "metric"],
) -> pd.DataFrame:
    """Build a pandas dataframe out of one or many GroupedMetrics.

    Args:
        metrics_data (GroupedMetrics | dict[str, GroupedMetrics]): Either a single GroupedMetrics
            or a (possibly nested) dict of them. A dict is flattened and every entry contributes
            its own rows to the result.
        index_names (list, optional): Names for the innermost row-index levels.
            Defaults to ["dataset", "metric"]. When the resulting index has more levels than
            names given, the extra outer levels are left unnamed (None).

    Returns:
        pd.DataFrame: Transposed metrics frame whose rows carry a MultiIndex of the flattened
            metric keys and whose columns come from the GroupedMetrics `index`.
    """
    if not isinstance(metrics_data, dict):
        # Single GroupedMetrics: one frame, transposed so the metric keys become rows.
        columns = metrics_data.index
        flat_metrics = flatten(_process_grouped_metrics(metrics_data))
        frame = pd.DataFrame(flat_metrics, index=columns).T
    else:
        # One transposed frame per flattened dict entry, stacked row-wise afterwards.
        pieces = []
        for key, grouped in flatten(metrics_data).items():
            flat_metrics = flatten({key: _process_grouped_metrics(grouped)})
            pieces.append(pd.DataFrame(flat_metrics, index=grouped.index).T)
        frame = pd.concat(pieces, axis=0)

    # Parent levels arrive as nested tuples, i.e. ((...), ...); index_flatten turns them back into
    # flat tuples. unflatten from flatten_dict is avoided since metrics_data consumes a lot of memory.
    frame.index = pd.MultiIndex.from_tuples(index_flatten(frame.index))

    # Left-pad the names with None so that the supplied names line up with the innermost levels.
    missing = len(frame.index.names) - len(index_names)
    if missing > 0:
        index_names = [None] * missing + index_names

    frame.index.set_names(index_names, inplace=True)
    return frame