get_comparison_df(df)

Compares the hyperparameters at the start and end of a negative feedback run.

Source code in wt_ml/negative_feedback/nf_inspect.py
33
34
35
36
37
38
39
40
41
42
def get_comparison_df(df: pd.DataFrame) -> pd.DataFrame:
    """Contrast the first and last ``current_hyperparam`` entries of a run.

    Args:
        df: Run history; the first and last rows of its ``"current_hyperparam"``
            selection are taken as the starting and final hyperparameter values.

    Returns:
        A DataFrame with columns ``original``, ``new``, ``changefrac``
        (``new / original``), ``logchange`` (natural log of ``changefrac``) and
        ``abslogchange``, ordered by ``abslogchange`` from largest to smallest.
    """
    hyperparams = df["current_hyperparam"]
    table = pd.DataFrame({"original": hyperparams.iloc[0], "new": hyperparams.iloc[-1]})

    # Derive the ratio once and reuse it for the log-scale columns.
    ratio = table["new"] / table["original"]
    log_ratio = np.log(ratio)
    table = table.assign(
        changefrac=ratio,
        logchange=log_ratio,
        abslogchange=np.abs(log_ratio),
    )

    # Largest relative movement (in either direction) comes first.
    return table.sort_values("abslogchange", ascending=False)

load_overfit_lambda_feedback(nf_data_file)

Loads the overfit lambda feedback data from a CSV file.

Source code in wt_ml/negative_feedback/nf_inspect.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def load_overfit_lambda_feedback(nf_data_file: str | Path) -> pd.DataFrame:
    """Loads the overfit lambda feedback data from a CSV file."""
    df_history_raw = pd.read_csv(nf_data_file)

    # Function to convert string representation of a dictionary back into a dictionary
    def parse_dict_column(column):
        return column.apply(eval)

    # Automatically identify columns with dictionary-like strings
    columns_to_parse = [col for col in df_history_raw.columns if df_history_raw[col].iloc[0].startswith("{")]
    parsed_dfs = {}
    for column in columns_to_parse:
        df_history_raw[column] = parse_dict_column(df_history_raw[column])
        parsed_dfs[column] = pd.json_normalize(df_history_raw[column])

    # Combine these DataFrames with multi-index
    combined_df = pd.concat([parsed_dfs[col] for col in columns_to_parse], axis=1, keys=columns_to_parse)
    combined_df.columns.names = ["type", "signal"]
    combined_df.index.name = "iter"
    return combined_df

plot_impacts(model, dataset, filename)

Plots the country-level impacts of the model.

Source code in wt_ml/negative_feedback/nf_inspect.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def plot_impacts(model, dataset, filename):
    """Render the model's aggregated impacts to ``filename`` without showing them.

    Args:
        model: Passed to ``get_data_all_batches`` as ``gt_model``.
        dataset: Data to run through the model; its ``encodings`` attribute is
            forwarded as well.
        filename: Forwarded to ``OutputImpact.visualize`` as ``file_name``.
    """
    batch_outputs = get_data_all_batches(
        dataset,
        gt_model=model,
        encodings=dataset.encodings,
        include_rois=False,
        collapse_level=0,
    )

    # Items 3 and 4 of the result are joined column-wise; presumably the impact
    # frame and the target frame respectively -- confirm against the helper.
    combined_impacts = pd.concat([batch_outputs[3], batch_outputs[4]], axis=1)

    # Presumably rolls the combined frame up to country level (per the original
    # docstring) -- confirm against aggregate_impact_df.
    country_level_impacts = aggregate_impact_df(combined_impacts)

    OutputImpact(all_impacts_df=country_level_impacts).visualize(show_plots=False, file_name=filename)