from math import prod

import numpy as np
import pandas as pd
import tensorflow as tf


def get_hier_params_df(hlayer: HierarchicalEmbedding) -> dict[str, pd.DataFrame]:
"""Get the learned weights for a HierarchicalEmbedding layer"""
n_features = hlayer.weights.shape[-1]
feature_names = hlayer.feature_names
if feature_names is None:
feature_names = list(range(n_features))
if len(feature_names) != n_features:
if n_features % len(feature_names) == 0:
            num_dups = n_features // len(feature_names)
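            # e.g. 2 names and n_features == 4 -> num_dups == 2, and
            # ["a", "b"] becomes ["a_1", "a_2", "b_1", "b_2"]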
feature_names = [f"{name}_{i+1}" for name in feature_names for i in range(num_dups)]
else:
            raise ValueError(f"feature_names must have length {n_features} or a divisor of it, but got length {len(feature_names)}")
output = dict()
for col_names in hlayer.columns:
if isinstance(col_names, str):
            # Normalize a bare column name to a one-element list so both forms are
            # handled uniformly
            col_names = [col_names]
num_cols = len(col_names)
cat_cols = [col for col in col_names if hlayer.encodings[col] != "continuous"]
cont_cols = [col for col in col_names if hlayer.encodings[col] == "continuous"]
num_cat_cols = len(cat_cols)
num_cont_cols = num_cols - num_cat_cols
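        # NOTE: only the first continuous column of a group is used below; any
        # additional continuous columns are ignored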
if num_cat_cols == 0:
hierarchy = {cont_cols[0]: np.asarray([1.0])}
output_index = [cont_cols[0]]
else:
midx = pd.MultiIndex.from_product([hlayer.encodings[c].values() for c in cat_cols], names=cat_cols)
output_index = pd.MultiIndex.from_product([hlayer.encodings[c].keys() for c in cat_cols], names=cat_cols)
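            # `midx` carries the integer codes (encoding values) used for the index
            # arithmetic below, while `output_index` carries the original category
            # labels (encoding keys) so the returned DataFrame is human-readable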
if num_cont_cols > 0:
                # Prepend a constant 1.0 level named after the continuous column so
                # `hierarchy` below gets an entry for it
                midx = pd.concat({1.0: pd.DataFrame(index=midx)}, names=[cont_cols[0]]).index
hierarchy = {h: midx.get_level_values(h).to_numpy() for h in midx.names}
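            # `hierarchy` now maps each column to the flat per-row level values,
            # e.g. for two categorical columns with codes 0..1 and 0..2:
            #   hierarchy[cat_cols[0]] -> [0, 0, 0, 1, 1, 1]
            #   hierarchy[cat_cols[1]] -> [0, 1, 2, 0, 1, 2]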
name = hlayer.stitched_cols(col_names)
        # Start of this column group's block in the stitched weight matrix
start = hlayer.offsets[name]
        if num_cont_cols == 0:
            # Purely categorical group: a unit weight per row
            shape = tf.shape(hierarchy[cat_cols[0]])
            weight = tf.ones(shape, dtype=tf.float32, name=f"{name}_weights")
        else:
            weight = hierarchy[cont_cols[0]]
        if num_cat_cols == 0:
            # With no categorical columns the group occupies a single row of the
            # weight matrix, so every entry points at the same offset
            index = tf.cast(
                tf.fill(tf.shape(weight), start),
                dtype=tf.int64,
                name=f"{name}_indices",
            )
        else:
            # Standard mixed-radix encoding of left-to-right indices, with base
            # col_counts[col] for each column
            offsets = np.cumprod([1] + [hlayer.col_counts[col] for col in cat_cols[:-1]])
            # `start` is the index in `weights` of the first embedding for this set
            # of columns. Concatenating all embeddings into a single weight matrix,
            # rather than defining them separately, lets us deterministically derive
            # each row's index in that larger matrix.
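            # e.g. col_counts of [4, 5] for two columns gives offsets == [1, 4], so
            # the row with codes (2, 3) lands at start + 1*2 + 4*3 == start + 14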
index = start + tf.math.add_n(
[
                    # hierarchy[col] is the flat array of integer codes for this column
tf.constant(offset, dtype=tf.int64) * tf.cast(hierarchy[col], dtype=tf.int64)
for offset, col in zip(offsets, cat_cols)
],
name=f"{name}_indices",
)
learned_weights = tf.gather(hlayer.weights, index, name="embeds")
        if len(learned_weights.shape) > 2:
            # Collapse any leading index dimensions so each row is one embedding
            flattened_shape = prod(learned_weights.shape[:-1])
            learned_weights = tf.reshape(learned_weights, (flattened_shape, n_features))
        output[name] = pd.DataFrame(learned_weights.numpy(), index=output_index, columns=feature_names)
return output
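

# A minimal usage sketch (hypothetical constructor arguments; the real
# HierarchicalEmbedding signature and encoding format may differ):
#
#   hlayer = HierarchicalEmbedding(
#       columns=[["store", "item"], "price"],
#       encodings={
#           "store": {"s1": 0, "s2": 1},
#           "item": {"i1": 0, "i2": 1, "i3": 2},
#           "price": "continuous",
#       },
#   )
#   params = get_hier_params_df(hlayer)
#   # One DataFrame per column group, indexed by the original category labels
#   params[hlayer.stitched_cols(["store", "item"])]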