get_all_hier_params(gt_model)

Get the learned weights for all HierarchicalEmbedding layers in a model

Source code in wt_ml/layers/hierarchical_utils.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def get_all_hier_params(gt_model: EconomicNetwork) -> dict:
    """Collect the learned weights of every HierarchicalEmbedding child of a model.

    Args:
        gt_model: Model whose ``children`` are scanned for ``HierchicalEmbedding`` layers.

    Returns:
        A dict keyed by layer name, built from the layer's weight-variable name with its
        first and last ``/``-separated components dropped. Each value is the per-column-group
        dict returned by ``get_hier_params_df``. When the layer has ``use_bias`` set, that
        inner dict also gets a ``"bias"`` entry: a one-row DataFrame whose columns are the
        signal (feature) names.
    """
    # Imported inside the function rather than at module level; presumably this avoids an
    # import cycle between the layer module and these utilities (TODO confirm).
    from wt_ml.layers.hier_embedding import HierchicalEmbedding

    hier_layers = [child for child in gt_model.children if isinstance(child, HierchicalEmbedding)]
    params_by_layer = {}
    for layer in hier_layers:
        scope_parts = layer.weights.name.split("/")
        layer_name = "/".join(scope_parts[1:-1])
        group_frames = get_hier_params_df(layer)
        if layer.use_bias:
            # All group frames are built with the same feature columns, so the first one
            # supplies the signal names used to label the bias vector.
            first_frame = tuple(group_frames.values())[0]
            group_frames["bias"] = pd.DataFrame(layer.bias.numpy(), index=first_frame.columns).T
        params_by_layer[layer_name] = group_frames
    return params_by_layer

get_hier_emb_summaries(gt_model)

Get summary statistics for all HierarchicalEmbedding layers in a model

Source code in wt_ml/layers/hierarchical_utils.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def get_hier_emb_summaries(gt_model: EconomicNetwork) -> pd.DataFrame:
    """Get summary statistics for all HierarchicalEmbedding layers in a model.

    Every (layer, hierarchy category) weight DataFrame from ``get_all_hier_params`` is
    summarised with ``DataFrame.describe()``. An extra ``l2_norm`` row is added: the
    square root of the column-wise sum of squares.

    Args:
        gt_model: The model whose HierarchicalEmbedding layers are summarised.

    Returns:
        A DataFrame indexed by a ``(layer, hier_category, signal)`` MultiIndex, with one
        column per metric (the ``describe()`` statistics plus ``l2_norm``). If the model
        has no hierarchical embedding weights, an empty DataFrame with that index
        structure is returned. The previous behavior was to raise from ``pd.concat``.
    """
    index_names = ["layer", "hier_category", "signal"]
    summaries = []
    for hier_name, hier_cat_wise_emb in get_all_hier_params(gt_model).items():
        for hier_category, hier_embs in hier_cat_wise_emb.items():
            summary_df = hier_embs.describe()
            # Per-signal Euclidean norm taken over all rows of this category.
            summary_df.loc["l2_norm"] = np.sqrt(np.square(hier_embs).sum(0))
            summary_df.index.name = "metric"
            summary_df.columns = pd.MultiIndex.from_product(
                [[hier_name], [hier_category], summary_df.columns],
                names=index_names,
            )
            summaries.append(summary_df)

    if not summaries:
        # pd.concat([]) raises "No objects to concatenate"; return an empty, correctly
        # labelled frame instead so models without hierarchical layers are handled.
        empty_index = pd.MultiIndex.from_arrays([[], [], []], names=index_names)
        return pd.DataFrame(index=empty_index, columns=pd.Index([], name="metric"))

    # Concatenate side by side, then transpose so each (layer, category, signal) is a row.
    return pd.concat(summaries, axis=1).T

get_hier_params_df(hlayer)

Get the learned weights for a HierarchicalEmbedding layer

Source code in wt_ml/layers/hierarchical_utils.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def get_hier_params_df(hlayer: HierchicalEmbedding) -> dict[str, pd.DataFrame]:
    """Get the learned weights for a HierarchicalEmbedding layer.

    For each column group in ``hlayer.columns``, the embedding rows belonging to that
    group are gathered from ``hlayer.weights`` and returned as a DataFrame.

    Args:
        hlayer: The hierarchical embedding layer to inspect. The following attributes are
            read: ``weights``, ``feature_names``, ``columns``, ``encodings``,
            ``col_counts``, ``offsets`` and ``stitched_cols``.

    Returns:
        A dict keyed by ``hlayer.stitched_cols(col_names)`` for each column group.
        Each value is a DataFrame with one column per feature name. Its rows are:
        - for groups with categorical columns, one row per combination of categories,
          indexed by the encoding keys;
        - for purely continuous groups, a single row labelled with the first continuous
          column's name.

    Raises:
        ValueError: If ``hlayer.feature_names`` cannot be matched to the last dimension
            of ``hlayer.weights``. This happens when it is empty or its length does not
            evenly divide the number of features.
    """
    n_features = hlayer.weights.shape[-1]
    feature_names = hlayer.feature_names
    if feature_names is None:
        feature_names = list(range(n_features))
    if len(feature_names) != n_features:
        n_names = len(feature_names)
        # Several weight columns per name are allowed when they divide evenly; each name is
        # then expanded to name_1 .. name_k. The `n_names > 0` guard makes an empty list
        # fall through to the ValueError instead of raising ZeroDivisionError on `% 0`.
        if n_names > 0 and n_features % n_names == 0:
            num_dups = n_features // n_names
            feature_names = [f"{base}_{i + 1}" for base in feature_names for i in range(num_dups)]
        else:
            raise ValueError(f"feature_names must be a list of size {n_features}, but got size {len(feature_names)}")

    output = {}
    for col_names in hlayer.columns:
        if isinstance(col_names, str):
            # A bare string denotes a single-column group; normalise it to a list.
            col_names = [col_names]
        cat_cols = [col for col in col_names if hlayer.encodings[col] != "continuous"]
        cont_cols = [col for col in col_names if hlayer.encodings[col] == "continuous"]

        name = hlayer.stitched_cols(col_names)
        # The start of the region of `hlayer.weights` holding this group's embeddings.
        start = hlayer.offsets[name]

        if not cat_cols:
            # Purely continuous group: a single embedding row located at `start`.
            output_index = [cont_cols[0]]
            index = tf.cast(tf.fill([1], start), dtype=tf.int64, name=f"{name}_indices")
        else:
            # Every combination of categories: `codes` holds the integer encodings used for
            # the lookup, `output_index` the matching keys used to label the rows.
            codes = pd.MultiIndex.from_product([hlayer.encodings[c].values() for c in cat_cols], names=cat_cols)
            output_index = pd.MultiIndex.from_product([hlayer.encodings[c].keys() for c in cat_cols], names=cat_cols)

            # Left-to-right mixed-radix encoding: the stride of each categorical column is
            # the product of the col_counts of the columns before it. All groups share one
            # weight matrix, and this derives each combination's row deterministically.
            # Rows are distinct provided each column's codes lie in [0, col_counts[col]).
            strides = np.cumprod([1] + [hlayer.col_counts[col] for col in cat_cols[:-1]])
            index = start + tf.math.add_n(
                [
                    tf.constant(stride, dtype=tf.int64)
                    * tf.cast(codes.get_level_values(col).to_numpy(), dtype=tf.int64)
                    for stride, col in zip(strides, cat_cols)
                ],
                name=f"{name}_indices",
            )

        learned_weights = tf.gather(hlayer.weights, index, name="embeds")
        if len(learned_weights.shape) > 2:
            # Collapse any extra leading dimensions so the result is 2-D (rows x features).
            flattened_shape = prod(learned_weights.shape[:-1])
            learned_weights = tf.reshape(learned_weights, (flattened_shape, n_features))
        output[name] = pd.DataFrame(learned_weights, index=output_index, columns=feature_names)
    return output