custom_reducer(parent_key, child_key)

Custom reducer function for flatten_dict.flatten

Source code in wt_ml/output/utils/intermediary_tracker_utils.py
26
27
28
29
30
31
32
33
34
35
def custom_reducer(parent_key: str, child_key: str) -> str:
    """Join a parent key and a child key into one flattened key.

    Intended for use as the reducer passed to ``flatten_dict.flatten``.

    Args:
        parent_key: Key accumulated so far; a falsy value (e.g. ``None`` or
            ``""``) means there is no parent yet.
        child_key: Key of the current level. Integer keys are joined with
            ``"_"``, every other key with ``"."``.

    Returns:
        ``str(child_key)`` when there is no parent, otherwise the parent and
        child joined by the chosen separator.
    """
    # Top level: nothing to prefix, just stringify the child.
    if not parent_key:
        return str(child_key)
    joiner = "_" if isinstance(child_key, int) else "."
    return f"{parent_key}{joiner}{child_key}"

get_intermediaries_tracker(model, dataset, intermediaries_trackers, batch_intermediary_processor=None, epoch_or_step='epoch', debug=True)

Get the intermediary tracker function for the given model, dataset, and intermediary trackers. This ensures that the model intermediaries for the full dataset are calculated only once for all the trackers. Each tracker is expected to have a batch processor, which processes the intermediaries of a single batch, and a post processor, which processes the collected batch results.

Source code in wt_ml/output/utils/intermediary_tracker_utils.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def get_intermediaries_tracker(
    model: TrainableModule,
    dataset: EconomicDataset,
    intermediaries_trackers: dict[str, dict[str, Callable]],
    batch_intermediary_processor: Callable[[Any], Any] = None,
    epoch_or_step: Literal["epoch", "step"] = "epoch",
    debug: bool = True,
) -> Callable[[], dict[str, Any]]:
    """
    Get the intermediary tracker function for the given model, dataset, and intermediary trackers.

    The returned function runs the model over every batch of the dataset once, hands each batch's
    intermediaries to the batch processor on a thread pool, and passes the list of batch results
    (in batch order) to the post processor.

    Args:
        model: Called as ``model(batch, training=False, debug=debug)``; the ``"temporalnet"`` entry of
            its output is moved to CPU with ``.to_cpu()``. ``model.step`` / ``model.epoch`` is read
            for the ``"step"`` entry of the result.
        dataset: Iterated batch by batch after ``dataset.reset_rng()`` is called.
        intermediaries_trackers: Must provide the ``"batch_processor"`` and ``"post_processor"`` entries.
            The batch processor is called as ``batch_processor(intermediary, batch_id=...)`` and the
            post processor as ``post_processor(batch_results)``; the latter must return a dict.
            NOTE(review): the annotation describes a nested dict, but both entries are called
            directly here -- confirm which one is intended.
        batch_intermediary_processor: Optional callable applied to each batch output before it is
            stored. ``None`` (the default) stores the batch output unchanged.
        epoch_or_step: ``"step"`` reports ``model.step`` under the ``"step"`` key; any other value
            reports ``model.epoch``.
        debug: Forwarded to the model call.

    Returns:
        A zero-argument function returning the post processor's dict merged with ``{"step": ...}``.
    """
    # Validate eagerly so a malformed configuration fails before a tracker function is handed out.
    _validate_intermediaries_trackers(intermediaries_trackers)

    def tracker_fn() -> dict[str, Any]:
        dataset.reset_rng()
        batch_intermediaries = []

        for batch in dataset:
            batch_output = model(batch, training=False, debug=debug)["temporalnet"].to_cpu()
            if batch_intermediary_processor is not None:
                batch_output = batch_intermediary_processor(batch_output)
            batch_intermediaries.append(batch_output)

        with ThreadPoolExecutor(max_workers=None) as executor:
            # Each worker receives its own deep copy, so a batch processor that mutates its input
            # cannot affect the stored intermediaries.
            futures = [
                executor.submit(
                    intermediaries_trackers["batch_processor"],
                    deepcopy(batch_intermediary),
                    batch_id=batch_id,
                )
                for batch_id, batch_intermediary in enumerate(batch_intermediaries)
            ]
            # Collect in submission order (not completion order) so that the list handed to the
            # post processor is deterministic: batch_results[k] always belongs to batch k.
            batch_results = [future.result() for future in futures]

        return intermediaries_trackers["post_processor"](batch_results) | {
            "step": model.step if epoch_or_step == "step" else model.epoch
        }

    return tracker_fn

intermediary_filter(intermediary)

Filter the intermediary names based on the inclusion and exclusion conditions

Source code in wt_ml/output/utils/intermediary_tracker_utils.py
73
74
75
76
77
def intermediary_filter(intermediary: str) -> bool:
    """Decide whether an intermediary name should be kept.

    A name is kept when it contains at least one entry of ``ARRS_TO_INCLUDE``
    and contains no entry of ``ARRS_TO_SKIP`` (substring matching).
    """

    def _mentions_any(fragments) -> bool:
        # True when any of the fragments occurs inside the intermediary name.
        return any(fragment in intermediary for fragment in fragments)

    is_included = _mentions_any(ARRS_TO_INCLUDE)
    is_skipped = _mentions_any(ARRS_TO_SKIP)
    return is_included and not is_skipped

map_layer_to_intermediary_names(inte_names)

Group the intermediary names by layer, returning a mapping from each layer name to the intermediary names that belong to it

Source code in wt_ml/output/utils/intermediary_tracker_utils.py
65
66
67
68
69
70
def map_layer_to_intermediary_names(inte_names: list[str]) -> dict[str, list[str]]:
    """Group intermediary names by the layer they belong to.

    The layer of each name is the first element returned by
    ``split_layer_param``. The result maps every layer to its intermediary
    names, in the order they appear in ``inte_names``.
    """
    names_by_layer = defaultdict(list)
    for inte_name in inte_names:
        layer, _ = split_layer_param(inte_name)
        names_by_layer[layer].append(inte_name)
    return names_by_layer

split_layer_param(name)

Split the intermediary name into layer and parameter

Source code in wt_ml/output/utils/intermediary_tracker_utils.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def split_layer_param(name: str) -> tuple[str, str]:
    """Split an intermediary name into a ``(layer, parameter)`` pair.

    Rules, applied after an ``"impacts."`` occurrence is handled by
    ``remove_impacts_prefix``:

    * names containing ``"baseline_layer"`` use ``"baseline_layer"`` as the
      layer and the name with ``"baseline_layer."`` removed as the parameter;
    * names containing ``"roicurve."`` with three or more dotted pieces use the
      first two pieces as the layer;
    * every other name uses the first dotted piece as the layer and the
      remaining pieces as the parameter.

    A parameter starting with ``"impact_by_signal."`` is finally passed through
    ``remove_impacts_prefix`` as well.
    """
    if "impacts." in name:
        name = remove_impacts_prefix(name, "impacts.")

    pieces = name.split(".")
    if "baseline_layer" in name:
        # The baseline layer takes precedence over every other rule.
        layer = "baseline_layer"
        param = name.replace("baseline_layer.", "")
    elif "roicurve." not in name or len(pieces) == 2:
        # Plain names and two-piece roicurve names: first piece is the layer.
        layer = pieces[0]
        param = ".".join(pieces[1:])
    else:
        layer = ".".join(pieces[:2])
        if len(pieces) == 3:
            param = ".".join(pieces[2:])
        else:
            # Longer roicurve names keep the second piece in the parameter too.
            param = ".".join(pieces[1:])

    if param.startswith("impact_by_signal."):
        param = remove_impacts_prefix(param, "impact_by_signal.")
    return layer, param