OutputHistogram

Bases: ModelOutputItem

Source code in wt_ml/output/output_histograms.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class OutputHistogram(ModelOutputItem):
    """Model output item that turns intermediary snapshots into per-layer histograms.

    The instance is built either from raw ``histogram_intermediaries`` (a list of
    dicts that each carry a ``"step"`` key next to the intermediary values) or
    from ``final_intermediaries``, which is used verbatim as the per-layer data.
    """

    def __init__(
        self, histogram_intermediaries: list[dict] | None = None, final_intermediaries: list[dict] | None = None
    ):
        """Store the inputs and forward them to ``ModelOutputItem``.

        Args:
            histogram_intermediaries: Raw snapshots; each dict is read through its
                ``"step"`` key plus the remaining intermediary entries.
            final_intermediaries: Already-grouped per-layer data. When given,
                ``layers_data`` returns it unchanged and the raw snapshots are not
                stored on this instance by this constructor.
        """
        super().__init__(final_intermediaries, [histogram_intermediaries])
        self.final_intermediaries = final_intermediaries
        if self.final_intermediaries is None:
            self.histogram_intermediaries = histogram_intermediaries

    def df(self):
        """Always raise: this output item has no dataframe representation.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("This method is not implemented for this class.")

    @cached_property
    def layers_data(self) -> list[dict]:
        """Return one dict per layer, mapping each step to that layer's filtered data.

        If ``final_intermediaries`` was supplied it is returned as-is. Otherwise the
        intermediary names (all keys of the first snapshot except ``"step"``) are
        grouped by ``map_layer_to_intermediary_names``, and for every group a dict
        ``{step: filter_intermediary_data(snapshot_without_step, names), ...,
        "layer": layer}`` is built.

        Returns:
            list[dict]: The per-layer dicts, or an empty list when there are no
                snapshots (``None`` or empty ``histogram_intermediaries``).
        """
        if self.final_intermediaries is not None:
            return self.final_intermediaries

        snapshots = self.histogram_intermediaries
        if not snapshots:
            # Nothing to group; avoids indexing snapshots[0] on empty/None input.
            return []

        # Names are taken from the first snapshot only.
        # NOTE(review): presumably every snapshot carries the same keys - confirm.
        intermediary_names_all = [name for name in snapshots[0] if name != "step"]
        layer_groups = map_layer_to_intermediary_names(intermediary_names_all)

        layer_hists_data_list = []
        for layer, intermediary_names in layer_groups.items():
            # Steps keep the order of `snapshots`. The previous code built a sorted
            # step list but only used it for an always-true membership test, so the
            # resulting order was already the input order.
            # NOTE(review): if ascending step order was intended, sort `snapshots`.
            layer_hists_data = {
                snapshot["step"]: filter_intermediary_data(
                    {k: v for k, v in snapshot.items() if k != "step"}, intermediary_names
                )
                for snapshot in snapshots
            }
            layer_hists_data["layer"] = layer
            layer_hists_data_list.append(layer_hists_data)

        return layer_hists_data_list

    def visualize(self) -> dict:
        """Build one histogram figure for every layer in ``self.layers_data``.

        Each entry of ``self.layers_data`` is split into its ``"layer"`` label and
        the remaining items, and both are passed to ``create_layer_histograms``.
        The method takes no arguments.

        Returns:
            dict: The result of ``create_layer_histograms`` for each entry, keyed
                by that entry's ``"layer"`` value (previously documented as
                ``dict[str, go.Figure]``).
        """
        figures_by_layer = {}
        for entry in self.layers_data:
            # Copy so the cached layers_data entry is not mutated by the pop.
            per_step = dict(entry)
            layer_name = per_step.pop("layer")
            figures_by_layer[layer_name] = create_layer_histograms(layer_name, per_step)
        return figures_by_layer

visualize()

Visualize layer param histograms

Parameters (taken from the docstring; the `visualize(self)` signature shown below accepts none of them):

Name Type Description Default
show_plots bool

Flag to indicate whether to show the plots. Defaults to True.

required
file_name str

Name of the file to save the plots. Defaults to "intermediary_histograms".

required
plot_title str

Title of the plot. Defaults to "Intermediaries - Histograms".

required
output_dir Path | str

Path to save the plots. Defaults to None.

required

Returns: dict[str, go.Figure]: Prepared histograms of layer params.

Source code in wt_ml/output/output_histograms.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def visualize(self):
    """Build one histogram figure for every layer in ``self.layers_data``.

    Each entry of ``self.layers_data`` is split into its ``"layer"`` label and
    the remaining items, and both are passed to ``create_layer_histograms``.
    The method takes no arguments.

    Returns:
        dict: The result of ``create_layer_histograms`` for each entry, keyed
            by that entry's ``"layer"`` value (previously documented as
            ``dict[str, go.Figure]``).
    """
    figures_by_layer = {}
    for entry in self.layers_data:
        # Copy so the entry held by layers_data is not mutated by the pop.
        per_step = dict(entry)
        layer_name = per_step.pop("layer")
        figures_by_layer[layer_name] = create_layer_histograms(layer_name, per_step)
    return figures_by_layer