31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def _count_params(variables) -> int:
    """Return the total number of elements across `variables` as a plain Python int."""
    # np.prod(()) is the float 1.0 (scalar variables) and np.sum([]) is the float 0.0,
    # which used to leak into the output as "1.0 params" / "0.0 params". Converting each
    # product to int and adding with the built-in sum keeps every count integral.
    # NOTE(review): assumes var.get_shape() yields fully-defined integer dimensions.
    return sum(int(np.prod(var.get_shape())) for var in variables)


def print_summary(
    initial: "Module", current: "Module", level=0, detailed=False, tab_size=2, col_margin=1, parent=None
) -> list[str]:
    """
    Recursively build the summary lines for a Module and all of its children.

    Nothing is written to stdout; the formatted lines are collected and returned so the
    caller decides how to display them.

    Args:
        initial (Module): The module the summary function was called from. Its entry is
            marked with a leading '*'.
        current (Module): The module being summarised by this call.
        level (int): Nesting depth of `current`. Should not be passed manually; the
            recursion increments it for each layer of children.
        detailed (bool): Show the detailed summary (a title table, a per-variable table
            and every counter, including zeros). Defaults to False.
        tab_size (int): Size of the indent added per level. With 0 nothing is indented
            and, in detailed mode, a "Parent" row is emitted instead. Defaults to 2.
        col_margin (int): Margin between columns, forwarded to `col_width`. Only used
            in the detailed summary. Defaults to 1.
        parent (Module): The parent of the `current` module, if any.

    Returns:
        list[str]: The summary lines of `current` followed by those of its descendants,
        depth first.
    """
    result = []
    indent = " " * (tab_size * level)
    # signifies if the current module is the same as the one summary was called from.
    this = "*" if initial == current else ""

    _local_trn_vars = current._local_trn_vars
    _local_non_trn_vars = current._local_non_trn_vars
    _local_all_variables = current._local_all_variables
    trn_vars = current.trn_vars
    non_trn_vars = current.non_trn_vars
    all_variables = current.all_variables

    # Insertion order matters: it is the order in which the counters are rendered.
    summary = {
        "total self-owned variables": len(_local_all_variables),
        "trainable self-owned variables": len(_local_trn_vars),
        "non-trainable self-owned variables": len(_local_non_trn_vars),
        "total self-owned params": _count_params(_local_all_variables),
        "trainable self-owned params": _count_params(_local_trn_vars),
        "non-trainable self-owned params": _count_params(_local_non_trn_vars),
        "total variables": len(all_variables),
        "trainable variables": len(trn_vars),
        "non-trainable variables": len(non_trn_vars),
        "total params": _count_params(all_variables),
        "trainable params": _count_params(trn_vars),
        "non-trainable params": _count_params(non_trn_vars),
    }

    if detailed:
        title_table = {
            "header": [
                f"{this}{current.name} ({type(current).__qualname__}) <{type(current)}>",
            ],
            "rows": [],
            "col_width": [],
        }
        if tab_size == 0 and parent:
            # With a non-zero tab_size the nesting already shows the parent.
            title_table["rows"].append([f"Parent: {parent}"])
        # col_width has to be calculated for each table
        title_table["col_width"] = col_width(title_table, col_margin)
        # max_len tracks the widest table so that all separators line up
        max_len = max_len_update(0, title_table["col_width"])
        tables = [title_table]

        if _local_all_variables:
            var_table = {
                "rows": [
                    (var.name, str(var.shape), var.dtype.name, str(var.trainable), str(_count_params([var])))
                    for var in _local_all_variables
                ],
                "header": ("Variable", "Shape", "dtype", "Trainable", "Params"),
            }
            var_table["col_width"] = col_width(var_table, col_margin)
            tables.append(var_table)
            max_len = max_len_update(max_len, var_table["col_width"])

        # now we render
        result.append(f"{indent}{'-' * max_len}")
        for t in tables:
            widths = t["col_width"]
            # using ljust and the column widths keeps the cells aligned
            header = "".join(th.ljust(widths[i]) for i, th in enumerate(t["header"]))
            result.append(f"{indent}{header}")
            result.append(f"{indent}{'=' * max_len}")  # header separator
            for row in t["rows"]:
                cells = "".join(tr.ljust(widths[i]) for i, tr in enumerate(row))
                result.append(f"{indent}{cells}")
            result.append(f"{indent}{'-' * max_len}")  # table separator
        # variable summary: every counter is listed, zeros included
        result.extend(f"{indent}{num} {desc}" for desc, num in summary.items())
    else:
        result.append(f"{indent}{this}{current.name} ({current.__class__.__name__})")
        # compact form: counters that are zero are left out
        mini_summary = ", ".join(f"{num} {desc}" for desc, num in summary.items() if num)
        if mini_summary:
            result.append(f"{indent} -> {mini_summary}")

    for child in current._local_children:
        # recursive call, one level deeper for each layer of children
        result += print_summary(
            initial,
            child,
            parent=current,
            level=level + 1,
            detailed=detailed,
            tab_size=tab_size,
            col_margin=col_margin,
        )
    return result
|