col_width(table, col_margin=1)

Calculates the column widths for padding. It computes the max string length for each column (plus the margin) so the table is displayed properly.

table (Dict[str, List]): {"header": [], "rows": [[]]}
col_margin (int): Extra spacing added to each column width. Defaults to 1.

Source code in wt_ml/module/module_utils.py
import numpy as np


def col_width(table, col_margin=1):
    """
    Calculates the column widths for padding.
    It calculates the max string length for each column so it is displayed properly.
    table (Dict[str, List]): {"header": [], "rows": [[]]}
    """
    rows = [table["header"]] + table["rows"]
    col_widths = np.array([[len(text) + col_margin for text in row] for row in rows])
    return col_widths.max(0).tolist()
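
A quick sketch of what col_width returns, using a small hypothetical table (not from the library):

table = {"header": ["name", "params"], "rows": [["dense_1", "1024"], ["bias", "32"]]}
widths = col_width(table, col_margin=1)
# widths == [8, 7]: the longest entry in each column plus the one-character margin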

max_len_update(max_len, col_widths)

Calculates the max width across tables, so that separators can be printed at the right length.

max_len (int): The current max_len value.
col_widths (List[int]): The output of the col_width function.

Source code in wt_ml/module/module_utils.py
def max_len_update(max_len, col_widths):
    """
    Calculates the max width of the tables, so that we can print separators.
    max_len (int): The current max_len value.
    col_widths (List[int]): The output of the col_width function.
    """
    return max(max_len, sum(col_widths))
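
Continuing that sketch, max_len_update keeps a running maximum of the total printed width:

max_len = max_len_update(0, [8, 7])        # 15, the sum of the column widths
max_len = max_len_update(max_len, [4, 4])  # still 15, since 4 + 4 < 15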

print_summary(initial, current, level=0, detailed=False, tab_size=2, col_margin=1, parent=None)

Recursively builds the summary of each Module and returns it as a list of lines.

initial (Module): The module the summary function was called from. When printed, it is marked with a '*'.
current (Module): The module to summarize.
level (int): How deep the module is nested. Should not be passed manually.
detailed (bool): Show a detailed summary. Defaults to False.
tab_size (int): Size of the indent for each level. If set to 0, the output is displayed as a flat table. Defaults to 2.
col_margin (int): Margin between columns. Used in the detailed summary. Defaults to 1.
parent (Module): The parent of the current module.

Source code in wt_ml/module/module_utils.py
def print_summary(
    initial: "Module", current: "Module", level=0, detailed=False, tab_size=2, col_margin=1, parent=None
) -> list[str]:
    """
    Recursively prints out summary of each Module.
    initial (Module): The module the summary function was called from. When printed out, it'll be signified by a '*'.
    current (Module): The module which will be printed out.
    level (int): Signifies how deep the module is. Should not be passed manually.
    detailed (bool): Show detailed summary. Defaults to False.
    tab_size (int): Size of indent, for each layer. If set to 0, will display as a table. Defaults to 2.
    col_margin (int): Margin between columns. Used in detailed summary. Defaults to 1.
    parent (Module): The parent of the `current` module.
    """
    result = []
    indent = " " * (tab_size * level)
    # signifies if the current module is the same as the one summary was called from.
    this = "*" if initial == current else ""

    _local_trn_vars = current._local_trn_vars
    _local_non_trn_vars = current._local_non_trn_vars
    _local_all_variables = current._local_all_variables
    trn_vars = current.trn_vars
    non_trn_vars = current.non_trn_vars
    all_variables = current.all_variables
    summary = {
        "total self-owned variables": len(_local_all_variables),
        "trainable self-owned variables": len(_local_trn_vars),
        "non-trainable self-owned variables": len(_local_non_trn_vars),
        "total self-owned params": np.sum([np.prod(var.get_shape()) for var in _local_all_variables]),
        "trainable self-owned params": np.sum([np.prod(var.get_shape()) for var in _local_trn_vars]),
        "non-trainable self-owned params": np.sum([np.prod(var.get_shape()) for var in _local_non_trn_vars]),
        "total variables": len(all_variables),
        "trainable variables": len(trn_vars),
        "non-trainable variables": len(non_trn_vars),
        "total params": np.sum([np.prod(var.get_shape()) for var in all_variables]),
        "trainable params": np.sum([np.prod(var.get_shape()) for var in trn_vars]),
        "non-trainable params": np.sum([np.prod(var.get_shape()) for var in non_trn_vars]),
    }
    if detailed:
        table = []
        table.append(
            {
                "header": [
                    f"{this}{current.name} ({type(current).__qualname__}) <{type(current)}>",
                ],
                "rows": [],
                "col_width": [],
            }
        )
        if tab_size == 0 and parent:
        # with a nonzero tab_size the indentation already shows the parent, so no need to print it
            table[0]["rows"].append([f"Parent: {parent}"])
        # col_width has to be calculated for each table
        table[0]["col_width"] = col_width(table[0], col_margin)
        # max_len should always be updated
        max_len = max_len_update(0, table[0]["col_width"])

        if len(_local_all_variables):
            table1 = {"rows": []}
            table1["header"] = ("Variable", "Shape", "dtype", "Trainable", "Params")
            total_params = 0
            for var in _local_all_variables:
                params = np.prod(var.get_shape())
                table1["rows"].append((var.name, str(var.shape), var.dtype.name, str(var.trainable), str(params)))
                total_params += params
            summary["total self-owned params"] = total_params
            table1["col_width"] = col_width(table1, col_margin)
            table.append(table1)
            max_len = max_len_update(max_len, table[-1]["col_width"])

        # now we print
        result.append(f"{indent}{'-' * max_len}")
        for t in table:
            # using ljust and col_widths, we can ensure the table is displayed properly
            header = "".join(th.ljust(t["col_width"][i]) for i, th in enumerate(t["header"]))
            result.append(f"{indent}{header}")
            result.append(f"{indent}{'=' * max_len}")  # header separator
            if t["rows"]:
                for row in t["rows"]:
                    rows = "".join(tr.ljust(t["col_width"][i]) for i, tr in enumerate(row))
                    result.append(f"{indent}{rows}")
            result.append(f"{indent}{'-'*max_len}")  # table separator
        for desc, num in summary.items():
            result.append(f"{indent}{num} {desc}")  # variable summary

    else:
        result.append(f"{indent}{this}{current.name} ({current.__class__.__name__})")
        mini_summary = ", ".join(f"{num} {desc}" for desc, num in summary.items() if num)
        if len(mini_summary):
            result.append(f"{indent}  -> {mini_summary}")

    for child in current._local_children:
        # recursive function, increases each level
        result += print_summary(
            initial, child, parent=current, level=level + 1, detailed=detailed, tab_size=tab_size, col_margin=col_margin
        )
    return result
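
A minimal usage sketch, assuming model is an instance of some wt_ml Module subclass (the name is illustrative, not from the source):

lines = print_summary(model, model, detailed=True)  # initial == current, so the root is marked with '*'
print("\n".join(lines))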