OutputConvergence

Bases: ModelOutputItem

Source code in wt_ml/output/output_convergence.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class OutputConvergence(ModelOutputItem):
    """Model output item exposing convergence data and its plots.

    The data is either passed in ready-made as ``final_df`` or built lazily
    (see ``df``) from the raw ``convergence_intermediaries``.
    """

    def __init__(
        self, convergence_intermediaries: list[dict] | None = None, final_df: dict[str, pd.DataFrame] | None = None
    ):
        """Store the precomputed frames, or the raw intermediaries to derive them from.

        Args:
            convergence_intermediaries: One dict per recorded step. Each dict must
                contain a "step" key; its remaining items are that step's results.
                Only kept on the instance when ``final_df`` is None.
            final_df: Precomputed frames. When given, ``df`` returns this mapping
                unchanged, so it needs the keys ``visualize`` reads:
                "impacts_by_signal_df", "other_df", "stability_per_param" and
                "stability_per_layer".
        """
        super().__init__(final_df, [convergence_intermediaries])
        self.final_df = final_df
        if self.final_df is None:
            self.convergence_intermediaries = convergence_intermediaries

    @cached_property
    def df(self) -> dict[str, pd.DataFrame]:
        """Return the convergence frames, computing them once on first access.

        Returns:
            ``final_df`` when it was supplied; otherwise a dict with the keys
            "impacts_by_signal_df", "other_df", "stability_per_param" and
            "stability_per_layer" derived from ``convergence_intermediaries``.
        """
        if self.final_df is not None:
            return self.final_df

        # Re-key the intermediaries by step; the "step" entry itself is dropped
        # from each step's payload.
        results = {
            intermediary["step"]: {k: v for k, v in intermediary.items() if k != "step"}
            for intermediary in self.convergence_intermediaries
        }
        impact_df, other_df = convert_results_to_df(results)
        stability_per_param, stability_per_layer = calculate_layer_convergence_stability(impact_df)
        return {
            "impacts_by_signal_df": impact_df,
            "other_df": other_df,
            "stability_per_param": stability_per_param,
            "stability_per_layer": stability_per_layer,
        }

    def visualize(self, return_figure_flag: bool = False) -> str | dict:
        """Visualize convergence plots of layers

        Args:
            return_figure_flag (bool, optional): If True, return the individual figures
                instead of the combined HTML. Defaults to False.
        Returns:
            str | dict[str, go.Figure]: Combined HTML of the four plots by default. When
                ``return_figure_flag`` is True, the figures keyed by "other_fig",
                "impact_fig", "stability_layer_fig" and "stability_param_fig".
        """
        frames = self.df

        # Insertion order doubles as the order of the plots in the combined HTML.
        figures = {
            "other_fig": make_plot(frames["other_df"]),
            "impact_fig": make_plot_with_dropdown(frames["impacts_by_signal_df"]),
            "stability_layer_fig": plot_convergence_stability(frames["stability_per_layer"]),
            "stability_param_fig": plot_convergence_stability_per_param(frames["stability_per_param"]),
        }

        # Callers that want the figures do not need the HTML, so skip rendering it.
        if return_figure_flag:
            return figures

        # Only the first fragment references plotly.js (via CDN); the remaining
        # fragments are emitted without it.
        fig_htmls = [
            pio.to_html(fig, full_html=False, include_plotlyjs="cdn" if i == 0 else False)
            for i, fig in enumerate(figures.values())
        ]
        return get_combined_html(fig_htmls)

visualize(return_figure_flag=False)

Visualize convergence plots of layers

Parameters:

Name Type Description Default
return_figure_flag bool

If True, return the individual figures as a dict instead of the combined HTML string.

False

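Returns: str | dict[str, go.Figure]: By default, a single HTML string combining the four convergence plots. When return_figure_flag is True, the prepared figures keyed by "other_fig", "impact_fig", "stability_layer_fig" and "stability_param_fig".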

Source code in wt_ml/output/output_convergence.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def visualize(self, return_figure_flag: bool = False) -> str | dict:
    """Visualize convergence plots of layers

    Args:
        return_figure_flag (bool, optional): If True, return the individual figures
            instead of the combined HTML. Defaults to False.
    Returns:
        str | dict[str, go.Figure]: Combined HTML of the four plots by default. When
            ``return_figure_flag`` is True, the figures keyed by "other_fig",
            "impact_fig", "stability_layer_fig" and "stability_param_fig".
    """
    frames = self.df

    # Insertion order doubles as the order of the plots in the combined HTML.
    figures = {
        "other_fig": make_plot(frames["other_df"]),
        "impact_fig": make_plot_with_dropdown(frames["impacts_by_signal_df"]),
        "stability_layer_fig": plot_convergence_stability(frames["stability_per_layer"]),
        "stability_param_fig": plot_convergence_stability_per_param(frames["stability_per_param"]),
    }

    # Callers that want the figures do not need the HTML, so skip rendering it.
    if return_figure_flag:
        return figures

    # Only the first fragment references plotly.js (via CDN); the remaining
    # fragments are emitted without it.
    fig_htmls = [
        pio.to_html(fig, full_html=False, include_plotlyjs="cdn" if i == 0 else False)
        for i, fig in enumerate(figures.values())
    ]
    return get_combined_html(fig_htmls)