OutputImpact

Bases: ModelOutputItem

Class to visualize the impacts and sales
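
A minimal usage sketch (it assumes `encodings` and `intermediaries` were produced by a trained wt_ml model; the variable names are illustrative, not part of this API):

output = OutputImpact(
    encodings=encodings,            # Encodings from the trained model
    intermediaries=intermediaries,  # one EconomicIntermediaries or a list of them
    collapse_level=1,               # collapse signal names one level up
)
figures = output.visualize()        # dict mapping granularity name -> go.Figure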

Source code in wt_ml/output/output_impact.py
class OutputImpact(ModelOutputItem):
    """Class to visualize the impacts and sales"""

    def __init__(
        self,
        encodings: Encodings | None = None,
        intermediaries: EconomicIntermediaries | list[EconomicIntermediaries] | None = None,
        separate_decay: bool = False,
        separate_dfs: bool = False,
        combine_granularities_flag: bool = False,
        total_impact_from_date: bool = False,
        collapse_level: CollapseLevel | dict[str, CollapseLevel] = 0,
        decompose_level: list[str] | None = None,
        show_specific: list[str] | None = None,
        with_groups: bool = False,
        collapse_lead_lag: bool | None = None,
        use_granularities: Sequence[str] | None = None,
        linearization_method: LinearizationMethod = "logspace",
        apply_sales_mask: bool = True,
        final_df: pd.DataFrame | list[pd.DataFrame] | None = None,
        is_animation_call: bool = False,
    ):
        super().__init__(final_df, [encodings, intermediaries])
        self.final_df = final_df
        self.separate_dfs = separate_dfs
        self.is_animation_call = is_animation_call

        if self.final_df is None:
            assert encodings is not None, "Encodings is required"
            self.encodings = encodings
            self.intermediaries: list[EconomicIntermediaries] = [
                RecursiveNamespace.parse(inter) if isinstance(inter, Mapping) else inter
                for inter in (intermediaries if isinstance(intermediaries, list) else [intermediaries])
            ]
            self.separate_decay = separate_decay
            self.combine_granularities_flag = combine_granularities_flag
            self.total_impact_from_date = total_impact_from_date
            if isinstance(collapse_level, dict):
                collapse_level = CollapseDict(collapse_level)

            self.collapse_level = collapse_level
            self.decompose_level = decompose_level
            self.show_specific = show_specific
            self.with_groups = with_groups
            if self.with_groups and self.collapse_level >= 2:
                self.with_groups = False
            # TODO (@Debarcha Mitra) map_signals uses collapse_lead_lags. Refactor these later to maintain consistency.
            self.collapse_lead_lag = collapse_lead_lag
            self.use_granularities = use_granularities
            self.linearization_method = linearization_method
            self.apply_sales_mask = apply_sales_mask

    def get_impacts_df(
        self, intermediary: EconomicIntermediaries
    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        # The returned frame is indexed by timestamps; its columns are a MultiIndex of
        # (granularity strings concatenated together,
        #  [signals + True Sales + Predicted Sales + 'steps' for animation frames]).
        if intermediary.impacts.roicurve is not None:
            additive_sublabels = (
                (
                    [to_signal_names(intermediary.impacts.roicurve.signal_names)]
                    + [to_signal_names(intermediary.impacts.roicurve.betagamma.signal_names)]
                )
                if self.separate_decay
                else [to_signal_names(intermediary.impacts.roicurve.signal_names)]
            )
            additive_impacts = get_additive_impacts(
                to_numpy(intermediary.impacts.roicurve.betagamma.impact_by_signal_total),
                to_numpy(intermediary.impacts.roicurve.impact_by_signal_instant),
                to_numpy(intermediary.impacts.roicurve.betagamma.impact_by_signal),
                to_numpy(intermediary.impacts.roicurve.impact_by_signal),
                self.total_impact_from_date,
                self.separate_decay,
            )
        else:
            additive_sublabels = []
            additive_impacts = []

        # TODO (@legendof-selda): This needs to be refactored so that we don't set it explicitly.
        # When new layers are added, we are forced to list them here.
        # Create additive and multiplicative properties for EconomicIntermediaries instead.
        multiplicative_intermediary_layers: tuple[tuple[Module, ...], ...] = (
            # Distribution gets merged with baseline effectively
            (intermediary.impacts.distribution,),
            # Pricing related things since we control it we want to apply before things we can't
            # control so it is easier to see how our choices are affecting sales
            (
                intermediary.impacts.pricing,
                intermediary.impacts.pricing_lead_lag_me,
                intermediary.impacts.price_ratio,
            ),
            # Extreme events
            (
                intermediary.impacts.bud_light_effect,
                intermediary.impacts.covid_effect,
                intermediary.impacts.pre_investment_effect,
                intermediary.impacts.drop_hold,
            ),
            # Trend lines of external economic factors
            (
                intermediary.impacts.national_trend_me,
                intermediary.impacts.regional_trend_me,
                intermediary.impacts.global_me,
            ),
            # Seasonality
            (
                intermediary.impacts.holiday_me,
                intermediary.impacts.periodic_me,
            ),
            # Weather applied last so it doesn't give others a false sense of seasonality
            (
                intermediary.impacts.weather_me,
                intermediary.impacts.temperature_me,
            ),
        )
        baseline = to_numpy(intermediary.baseline)
        cur_baseline = baseline.copy()
        linearized_mult_impacts_arr = []
        multiplicative_sublabels = []
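        # Peel off multiplicative layers one at a time: each layer is linearized against the
        # running baseline, which is then updated with the layer's product before the next layer.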
        for multiplicative_intermediaries in multiplicative_intermediary_layers:
            impact_label_pairs = [
                (to_numpy(inter.impact_by_signal), to_signal_names(inter.signal_names))
                for inter in multiplicative_intermediaries
                if inter is not None
            ]
            # Skip this layer entirely when every intermediary in it is absent.
            if not impact_label_pairs:
                continue
            multiplicative_impacts, multiplicative_layer_sublabels = zip(*impact_label_pairs)
            multiplicative_sublabels = [*multiplicative_sublabels, *multiplicative_layer_sublabels]
            mult_layer_impacts = np.concatenate(multiplicative_impacts, axis=2)
            # Note: values differ slightly between CPU and GPU; the difference is typically about 1e-6 around 1.0.
            mult_layer_impacts = np.where(np.abs(mult_layer_impacts - 1) <= 1e-6, 1, mult_layer_impacts)
            linearized_mult_layer_impacts = linearize_multiplicative_impacts(
                mult_layer_impacts, cur_baseline, method=self.linearization_method
            )
            linearized_mult_impacts_arr.append(linearized_mult_layer_impacts)
            cur_baseline *= mult_layer_impacts.prod(axis=2)
        linearized_mult_impacts = np.concatenate(linearized_mult_impacts_arr, axis=2)
        predicted = to_numpy(intermediary.yhat).astype(np.float32)
        true_sales = intermediary.inputs.true_sales if intermediary.inputs is not None else intermediary.y_true
        true_sales = to_numpy(true_sales).astype(np.float32) if true_sales is not None else None
        smoothed_sales = (
            to_numpy(intermediary.y_smooth).astype(np.float32) if intermediary.y_smooth is not None else None
        )
        correction_feature_labels = []
        correction_feature_impacts = []
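        # Attribute residual model error to "correction" features: wherever a feature's mask is
        # off, the remaining (true_sales - predicted) error is assigned to that feature as an impact.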
        if getattr(intermediary.inputs, "feature_masks", None) is not None:
            any_feature_mask = to_numpy(intermediary.inputs.feature_masks).all(2)
            if true_sales is not None:
                if smoothed_sales is not None:
                    smoothed_sales[~any_feature_mask] = true_sales[~any_feature_mask]
                model_error = true_sales - predicted
                predicted[~any_feature_mask] = true_sales[~any_feature_mask]
            else:
                model_error = np.zeros_like(predicted)
            remaining_error = model_error
            for label, i in self.encodings["feature_mask"].items():
                mask = to_numpy(intermediary.inputs.feature_masks[:, :, i])
                impact = remaining_error * ~mask
                remaining_error = remaining_error * mask
                correction_feature_labels.append(label)
                correction_feature_impacts.append(impact)
        num_time = baseline.shape[1]
        granularity_names, level_names = get_granularity_names(self.encodings, intermediary.inputs)
        if intermediary.inputs is not None:
            date_lookup = get_lookups(self.encodings["date"])
            index = pd.DatetimeIndex([date_lookup[i] for i in to_numpy(intermediary.inputs.date_index)], name="time")
            no_prediction_mask = intermediary.inputs.no_prediction_mask
        else:
            index = None
            no_prediction_mask = None

        column_names = (*level_names, "group", "signal") if self.with_groups else (*level_names, "signal")

        granular_impacts_matrix = get_granular_total_impacts_matrix(
            baseline,
            linearized_mult_impacts,
            additive_impacts,
            correction_feature_impacts,
            self.apply_sales_mask,
            to_numpy(no_prediction_mask) if no_prediction_mask is not None else None,
        )

        impacts_df = self.get_granular_total_impacts_df(
            intermediary,
            granular_impacts_matrix,
            multiplicative_sublabels,
            additive_sublabels,
            correction_feature_labels,
            num_time,
            granularity_names,
            level_names,
            index,
            column_names,
        )

        step_df = pd.DataFrame(
            np.full((num_time, len(granularity_names)), fill_value=1.0) * to_numpy(intermediary.step),
            columns=pd.MultiIndex.from_tuples(
                [
                    (*granularity, "", STEP_COL_NAME) if self.with_groups else (*granularity, STEP_COL_NAME)
                    for granularity in granularity_names
                ],
                names=column_names,
            ),
            index=index,
        )

        revenue_df = self.get_revenue_df(
            intermediary,
            granularity_names,
            index,
            column_names,
            predicted,
            smoothed_sales,
        )

        if self.separate_dfs:
            return impacts_df, step_df, revenue_df
        return pd.concat([impacts_df, step_df, revenue_df], axis=1).sort_index()

    def _generate_signal_mappings_df(
        self, intermediary: EconomicIntermediaries, collapse_lead_lag: bool = True
    ) -> pd.DataFrame:
        all_signal_mappings = {
            f"signal_{level}": map_signals(
                self.encodings, intermediary, collapse_level=level, collapse_lead_lags=collapse_lead_lag
            )
            for level in range(MAX_COLLAPSE_LEVEL + 1)
        }

        signal_mappings_df = pd.DataFrame.from_dict(all_signal_mappings, orient="index").T
        return signal_mappings_df

    def _customize_signal_mappings(self, signal_mappings_df: pd.DataFrame) -> dict[str, str]:
        """Appropriately customize the signal_mappings that the model is trained on to custom values specified by
        1. `decompose_level` for top-down decomposition of specified levels one level downward
        2. `show_specific` for bottom-up propagation of the values to all the levels upwards
        """
        collapse_level = self.collapse_level if self.collapse_level <= MAX_COLLAPSE_LEVEL else MAX_COLLAPSE_LEVEL

        # E.g. Price Ratio at collapse level 1 is remapped to Price Ratio at collapse level 2.
        # E.g. Christmas Day at collapse level 0 is remapped to Christmas Day at collapse levels 1 & 2.
        if self.show_specific:
            # TODO: @Stalin fix this to work if `collapse_level=CollapseDict`
            for level in range(collapse_level):
                spec_mask = signal_mappings_df[f"signal_{level}"].isin(self.show_specific)
                curr_level = level + 1
                while curr_level <= collapse_level:
                    signal_mappings_df.loc[spec_mask, f"signal_{curr_level}"] = signal_mappings_df.loc[
                        spec_mask, f"signal_{level}"
                    ]
                    curr_level += 1

        # E.g. Holidays / Seasonality at collapse level 2 is remapped to Holiday & Time of year from collapse level 1.
        if self.decompose_level and collapse_level >= 1:
            decomp_mask = signal_mappings_df[f"signal_{collapse_level}"].isin(self.decompose_level)
            signal_mappings_df.loc[decomp_mask, f"signal_{collapse_level}"] = signal_mappings_df.loc[
                decomp_mask, f"signal_{collapse_level-1}"
            ]
        signal_mappings = signal_mappings_df[f"signal_{collapse_level}"].to_dict()
        return signal_mappings

    def get_granular_total_impacts_df(
        self,
        intermediary: EconomicIntermediaries,
        granular_impacts: NDArray,
        multiplicative_sublabels: list[list[str]],
        additive_sublabels: list[list[str]],
        correction_sublabels: list[str],
        num_time: int,
        granularity_names: list[list[str]],
        level_names: list[str],
        index: pd.DatetimeIndex,
        column_names: tuple[str, ...],
    ) -> pd.DataFrame:
        labels = ["baseline"] + list(
            itertools.chain(*multiplicative_sublabels, *additive_sublabels, correction_sublabels)
        )
        columns = pd.MultiIndex.from_tuples(
            [(*granularity, label) for granularity in granularity_names for label in labels],
            names=(*level_names, "signal"),
        )
        impacts_df = pd.DataFrame(
            np.transpose(granular_impacts, [1, 0, 2]).reshape(num_time, -1) * self.encodings["normalization_factor"],
            columns=columns,
            index=index,
        )
        collapse_lead_lag = self.collapse_lead_lag or self.collapse_level > 0

        if isinstance(self.collapse_level, CollapseDict):
            signal_mappings = map_signals(
                self.encodings, intermediary, collapse_level=self.collapse_level, collapse_lead_lags=collapse_lead_lag
            )
            if self.show_specific:
                logger.warning(f"{self.show_specific} won't be used as a dictionary is passed in collapse_level")
            if self.decompose_level:
                logger.warning(f"{self.decompose_level} won't be used as a dictionary is passed in collapse_level")
        else:
            signal_mappings_df = self._generate_signal_mappings_df(intermediary, collapse_lead_lag=collapse_lead_lag)
            signal_mappings = self._customize_signal_mappings(signal_mappings_df)

        orderings = {}
        for signal in impacts_df.columns.get_level_values("signal"):
            if signal_mappings[signal] not in orderings:
                orderings[signal_mappings[signal]] = len(orderings)

        impacts_df = (
            impacts_df.groupby(
                by=[*level_names, impacts_df.columns.get_level_values("signal").map(signal_mappings)], axis=1
            )
            .sum()
            .sort_index(axis=1, key=lambda x: x.map(orderings) if x.name == "signal" else x)
        )

        if self.with_groups:
            impacts_df.columns = self.get_grouped_columns(
                intermediary,
                impacts_df.columns,
                column_names,
                signal_mappings,
            )

        return impacts_df

    def get_grouped_columns(
        self,
        intermediary: EconomicIntermediaries,
        impact_df_columns: list[tuple[str, ...]],
        column_names: tuple[str, ...],
        signal_mappings: dict,
    ) -> pd.MultiIndex:
        grouped_signal_mappings = DefaultIdentityDict(
            map_signals(self.encodings, intermediary, collapse_level=self.collapse_level + 1, collapse_lead_lags=True)
        )
        # org signals are changed via signal_mappings. getting the groups on higher collapse level
        grouped_signal_mappings = DefaultIdentityDict(
            grouped_signal_mappings
            if self.collapse_level == 0 and self.collapse_lead_lag is False
            else {mapping: grouped_signal_mappings[signal] for signal, mapping in signal_mappings.items()}
        )
        grouped_column_index = pd.MultiIndex.from_tuples(
            [(*col[:-1], grouped_signal_mappings[col[-1]], col[-1]) for col in impact_df_columns],
            names=column_names,
        )

        return grouped_column_index

    def get_revenue_df(
        self,
        intermediary: EconomicIntermediaries,
        granularity_names: list[list[str]],
        index: pd.DatetimeIndex,
        column_names: tuple[str, ...],
        predicted: np.ndarray,
        smoothed_sales: np.ndarray | None,
    ) -> pd.DataFrame:
        no_prediction_mask = getattr(intermediary.inputs, "no_prediction_mask", None)
        revenue_matrix = get_revenue_matrix(
            predicted,
            smoothed_sales,
            self.encodings["normalization_factor"],
            self.apply_sales_mask,
            to_numpy(no_prediction_mask) if no_prediction_mask is not None else np.ones_like(predicted),
        )

        cols = pd.MultiIndex.from_tuples(
            [
                (*granularity, "", signal_type) if self.with_groups else (*granularity, signal_type)
                for granularity in granularity_names
                for signal_type in ["Predicted Sales"] + (["Smoothed Sales"] if smoothed_sales is not None else [])
            ],
            names=column_names,
        )

        revenue_df = pd.DataFrame(revenue_matrix, columns=cols, index=index)

        if intermediary.inputs is not None and intermediary.inputs.true_sales is not None:
            original_sales_df = (
                get_sales_df(intermediary.inputs, self.encodings) * self.encodings["normalization_factor"]
            )
            original_sales_df.columns = pd.MultiIndex.from_tuples(
                [
                    (*existing, "", REV_NAME) if self.with_groups else (*existing, REV_NAME)
                    for existing in original_sales_df.columns
                ],
                names=(
                    [*original_sales_df.columns.names, "group", "signal"]
                    if self.with_groups
                    else [*original_sales_df.columns.names, "signal"]
                ),
            )

            revenue_df = pd.concat(
                [
                    revenue_df,
                    original_sales_df.reorder_levels(revenue_df.columns.names, axis=1),
                ],
                axis=1,
            )

        return revenue_df.sort_index(axis=1)

    def make_impact_plots(
        self,
        all_impacts_df: pd.DataFrame,
        granularity_names: list[str],
        height: int | None,
        group_click: Literal["togglegroup", "toggleitem"],
        legend_groups: dict,
        extra_titles_map: dict | None = None,
        animation_frame: str | None = None,
        markers: bool = True,
    ):
        all_plots = {}
        for granularity_name in granularity_names:
            modified_title = prepare_modified_title(granularity_name, extra_titles_map)
            impacts_df = all_impacts_df.xs(granularity_name, axis=1, level="granularity")

            signal_names = [col for col in impacts_df.columns if col not in [*LINE_NAMES, STEP_COL_NAME]]
            range_y = get_range(impacts_df, LINE_NAMES, signal_names)
            # Hide zero-valued line traces by mapping exact zeros to NaN.
            impacts_df.loc[:, LINE_NAMES] = impacts_df.loc[:, LINE_NAMES].replace(0, np.nan)

            impact_plot = px.bar(
                impacts_df,
                x=impacts_df.index,
                y=signal_names,
                animation_frame=animation_frame,
                animation_group=None if animation_frame is None else impacts_df.index,
                range_y=range_y,
                height=height,
                title=modified_title,
                labels=signal_names,
            )

            update_impact_plot_legend(impact_plot, legend_groups, group_click)

            pred_plot = px.line(
                impacts_df,
                x=impacts_df.index,
                y=LINE_NAMES,
                animation_frame=animation_frame,
                animation_group=None if animation_frame is None else impacts_df.index,
                range_y=range_y,
                height=height,
                title=modified_title,
                labels=LINE_NAMES,
                markers=markers,
            )

            plot = merge_plots_no_layout(impact_plot, pred_plot)
            all_plots[granularity_name] = plot
        return all_plots

    @cached_property
    def df(
        self,
    ) -> (
        pd.DataFrame
        | list[pd.DataFrame]
        | tuple[list[pd.DataFrame], list[pd.DataFrame], list[pd.DataFrame]]
        | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
    ):
        """Get the impacts dataframe"""
        if self.final_df is not None:
            return self.final_df
        if self.separate_dfs:
            all_impacts: list[pd.DataFrame] = []
            all_steps: list[pd.DataFrame] = []
            all_ys: list[pd.DataFrame] = []
            for step_intermediary in self.intermediaries:
                step_impacts_df, step_df, step_y_df = self.get_impacts_df(step_intermediary)
                if TYPE_CHECKING:
                    assert isinstance(step_impacts_df, pd.DataFrame)
                    assert isinstance(step_df, pd.DataFrame)
                    assert isinstance(step_y_df, pd.DataFrame)

                if self.combine_granularities_flag:
                    step_impacts_df = combine_granularities(step_impacts_df)
                    step_df = combine_granularities(step_df)
                    step_y_df = combine_granularities(step_y_df)
                all_impacts.append(step_impacts_df)
                all_steps.append(step_df)
                all_ys.append(step_y_df)

            if len(self.intermediaries) == 1:
                return all_impacts[0], all_steps[0], all_ys[0]
            else:
                return all_impacts, all_steps, all_ys
        else:
            impact_dfs = []
            step = None
            for step_intermediary in self.intermediaries:
                if step == to_numpy(step_intermediary.step):
                    continue
                if self.is_animation_call:
                    step = to_numpy(step_intermediary.step)
                step_impacts_df = self.get_impacts_df(step_intermediary)
                if TYPE_CHECKING:
                    assert isinstance(step_impacts_df, pd.DataFrame)
                impact_dfs.append(
                    combine_granularities(step_impacts_df) if self.combine_granularities_flag else step_impacts_df
                )
            impact_dfs = impact_dfs[0] if len(self.intermediaries) == 1 else impact_dfs

            return impact_dfs

    @cached_property
    def visualization_df(self):
        if not self.separate_dfs:
            return self.df

        all_impacts_df, all_step_df, all_revenue_df = self.df
        if TYPE_CHECKING:
            assert isinstance(all_impacts_df, list)
            assert isinstance(all_step_df, list)
            assert isinstance(all_revenue_df, list)

        if isinstance(all_impacts_df, list):
            all_impacts_df = [
                pd.concat([step_impact_df, step_step_df, step_revenue_df], axis=1).sort_index()
                for step_impact_df, step_step_df, step_revenue_df in zip(all_impacts_df, all_step_df, all_revenue_df)
            ]

        return all_impacts_df

    def visualize(
        self,
        wibbles: list[str] | None = None,
        wibble_encodings: dict | None = None,
        height: int | None = 800,
        group_click: Literal["togglegroup", "toggleitem"] = "togglegroup",
        extra_titles_map: dict | None = None,
        markers: bool = True,
    ) -> dict[str, go.Figure]:
        """Visualise the impacts and sales

        Args:
            wibbles (list[str] | None, optional): Names of all the wibbles/keys. Defaults to None.
            wibble_encodings (dict | None, optional): Wibble encodings. Defaults to None.
            height (int | None, optional): The height of the plots in pixels. Defaults to 800.
            group_click (str): Determines how clicking on a legend group affects the visibility of
                                the traces associated with that group. Defaults to "togglegroup".
            extra_titles_map (dict | None, optional): Mapping of extra titles. Defaults to None.
            markers (bool, optional): Whether to draw markers on the line traces. Defaults to True.

        Returns:
            dict[str, go.Figure]: Combined bar and line charts of impacts and revenues, keyed by
            granularity.
        """
        all_impacts_df = self.visualization_df

        # Prepare animation frame and all_impacts_df
        animation_frame = STEP_COL_NAME if self.is_animation_call else None
        if isinstance(all_impacts_df, list):
            if self.is_animation_call:
                all_impacts_df = pd.concat(all_impacts_df, axis=0)
            else:
                all_impacts_df = pd.concat(all_impacts_df, axis=1)

        legend_groups = prepare_legend_groups(all_impacts_df)
        all_impacts_df = drop_no_impact(all_impacts_df)

        if wibbles is None:
            if wibble_encodings:
                wibbles = list(wibble_encodings.keys())
            else:
                wibbles = all_impacts_df.columns.unique("granularity").tolist()

        all_plots = self.make_impact_plots(
            all_impacts_df,
            wibbles,
            height,
            group_click,
            legend_groups,
            extra_titles_map,
            animation_frame,
            markers,
        )

        return all_plots

    def visualize_impacts_change(
        self,
        height: int = 800,
        show_plots: bool = False,
        offset: float = EPSILON,
    ) -> go.Figure:
        """Visualise difference in latest impact from the others

        Args:
            height (int, optional): The height of the plot in pixels. Defaults to 800.
            show_plots (bool, optional): Flag to indicate whether to show the plot. Defaults to False.
            offset (float, optional): A small value added to the denominator to avoid division by zero.
                Defaults to EPSILON.

        Returns:
            go.Figure: Plot of change in impacts
        """

        if self.separate_dfs:
            all_impacts_df = self.df[0]
        else:
            all_impacts_df = self.df

        if isinstance(all_impacts_df, list):
            first_step_impacts = all_impacts_df[0]
        else:
            first_step_impacts = all_impacts_df
            all_impacts_df = [all_impacts_df]

        granularity_names = first_step_impacts.columns.get_level_values("granularity").unique()
        col_to_drop = [(granularity, STEP_COL_NAME) for granularity in granularity_names] + [
            (granularity, REV_NAME) for granularity in granularity_names
        ]

        impacts_np = np.stack(
            [df.drop(columns=col_to_drop).sort_index(axis=1).values.reshape(-1) for df in all_impacts_df], axis=0
        )

        i_s = list(range(1, impacts_np.shape[0]))
        last_impact = impacts_np[-1]
        abs_diffs = np.abs(last_impact - impacts_np[:-1])
        rel_diffs = 2 * abs_diffs / (np.abs(last_impact) + np.abs(impacts_np[:-1]) + offset)
        diffs_np = np.stack([abs_diffs.max(1), rel_diffs.max(1)], axis=1)[::-1]
        diffs_df = pd.DataFrame(diffs_np, index=i_s, columns=["abs", "rel"])
        plot = diffs_df.plot(backend="plotly", height=height)
        if show_plots:
            plot.show()
        return plot

df: pd.DataFrame | list[pd.DataFrame] | tuple[list[pd.DataFrame], list[pd.DataFrame], list[pd.DataFrame]] | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame] cached property

Get the impacts dataframe
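
A minimal sketch of unpacking the property (here `output` is a hypothetical OutputImpact built with `separate_dfs=True` and a single intermediary, in which case `.df` is a 3-tuple; with `separate_dfs=False` and one intermediary it is a single combined DataFrame):

impacts_df, step_df, revenue_df = output.df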

visualize(wibbles=None, wibble_encodings=None, height=800, group_click='togglegroup', extra_titles_map=None, markers=True)

Visualise the impacts and sales

Parameters:

wibbles (list[str] | None): Names of all the wibbles/keys. Defaults to None.
wibble_encodings (dict | None): Wibble encodings. Defaults to None.
height (int | None): The height of the plots in pixels. Defaults to 800.
group_click (Literal['togglegroup', 'toggleitem']): Determines how clicking on a legend group affects the visibility of the traces associated with that group. Defaults to 'togglegroup'.
extra_titles_map (dict | None): Mapping of extra titles. Defaults to None.
markers (bool): Whether to draw markers on the line traces. Defaults to True.

Returns:

dict[str, go.Figure]: Combined bar and line charts of impacts and revenues, keyed by granularity.
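
A hypothetical call, assuming `output` is an OutputImpact instance and "granularity_a" is one of its granularity names:

figures = output.visualize(height=600, group_click="toggleitem")
figures["granularity_a"].show()  # one combined bar-and-line figure per granularity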

Source code in wt_ml/output/output_impact.py
def visualize(
    self,
    wibbles: list[str] | None = None,
    wibble_encodings: dict | None = None,
    height: int | None = 800,
    group_click: Literal["togglegroup", "toggleitem"] = "togglegroup",
    extra_titles_map: dict | None = None,
    markers: bool = True,
) -> dict[str, go.Figure]:
    """Visualise the impacts and sales

    Args:
        wibbles (list[str] | None, optional): Names of all the wibbles/keys. Defaults to None.
        wibble_encodings (dict | None, optional): Wibble encodings. Defaults to None.
        height (int | None, optional): The height of the plots in pixels. Defaults to 800.
        group_click (str): Determines how clicking on a legend group affects the visibility of
                            the traces associated with that group. Defaults to "togglegroup".
        extra_titles_map (dict | None, optional): Mapping of extra titles. Defaults to None.
        markers (bool, optional): Whether to draw markers on the line traces. Defaults to True.

    Returns:
        dict[str, go.Figure]: Combined bar and line charts of impacts and revenues, keyed by
        granularity.
    """
    all_impacts_df = self.visualization_df

    # Prepare animation frame and all_impacts_df
    animation_frame = STEP_COL_NAME if self.is_animation_call else None
    if isinstance(all_impacts_df, list):
        if self.is_animation_call:
            all_impacts_df = pd.concat(all_impacts_df, axis=0)
        else:
            all_impacts_df = pd.concat(all_impacts_df, axis=1)

    legend_groups = prepare_legend_groups(all_impacts_df)
    all_impacts_df = drop_no_impact(all_impacts_df)

    if wibbles is None:
        if wibble_encodings:
            wibbles = list(wibble_encodings.keys())
        else:
            wibbles = all_impacts_df.columns.unique("granularity").tolist()

    all_plots = self.make_impact_plots(
        all_impacts_df,
        wibbles,
        height,
        group_click,
        legend_groups,
        extra_titles_map,
        animation_frame,
        markers,
    )

    return all_plots

visualize_impacts_change(height=800, show_plots=False, offset=EPSILON)

Visualise the difference between the latest impacts and those from earlier steps

Parameters:

height (int): The height of the plot in pixels. Defaults to 800.
show_plots (bool): Flag to indicate whether to show the plot. Defaults to False.
offset (float): A small value added to the denominator to avoid division by zero. Defaults to EPSILON.

Returns:

go.Figure: Plot of change in impacts.
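
The relative difference plotted is the symmetric form computed in the source; a standalone numpy sketch of the same metric (the helper name is illustrative):

import numpy as np

def symmetric_rel_diff(last: np.ndarray, earlier: np.ndarray, offset: float) -> np.ndarray:
    # 2|a - b| / (|a| + |b| + offset): symmetric in its arguments and bounded by 2,
    # with `offset` keeping the ratio finite when both impacts are zero.
    return 2 * np.abs(last - earlier) / (np.abs(last) + np.abs(earlier) + offset)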

Source code in wt_ml/output/output_impact.py
def visualize_impacts_change(
    self,
    height: int = 800,
    show_plots: bool = False,
    offset: float = EPSILON,
) -> go.Figure:
    """Visualise difference in latest impact from the others

    Args:
        height (int, optional): The height of the plot in pixels. Defaults to 800.
        show_plots (bool, optional): Flag to indicate whether to show the plot. Defaults to False.
        offset (float, optional): A small value added to the denominator to avoid division by zero.
            Defaults to EPSILON.

    Returns:
        go.Figure: Plot of change in impacts
    """

    if self.separate_dfs:
        all_impacts_df = self.df[0]
    else:
        all_impacts_df = self.df

    if isinstance(all_impacts_df, list):
        first_step_impacts = all_impacts_df[0]
    else:
        first_step_impacts = all_impacts_df
        all_impacts_df = [all_impacts_df]

    granularity_names = first_step_impacts.columns.get_level_values("granularity").unique()
    col_to_drop = [(granularity, STEP_COL_NAME) for granularity in granularity_names] + [
        (granularity, REV_NAME) for granularity in granularity_names
    ]

    impacts_np = np.stack(
        [df.drop(columns=col_to_drop).sort_index(axis=1).values.reshape(-1) for df in all_impacts_df], axis=0
    )

    i_s = list(range(1, impacts_np.shape[0]))
    last_impact = impacts_np[-1]
    abs_diffs = np.abs(last_impact - impacts_np[:-1])
    rel_diffs = 2 * abs_diffs / (np.abs(last_impact) + np.abs(impacts_np[:-1]) + offset)
    diffs_np = np.stack([abs_diffs.max(1), rel_diffs.max(1)], axis=1)[::-1]
    diffs_df = pd.DataFrame(diffs_np, index=i_s, columns=["abs", "rel"])
    plot = diffs_df.plot(backend="plotly", height=height)
    if show_plots:
        plot.show()
    return plot