TemporalNet

Bases: EconomicNetwork

Source code in wt_ml/networks/temporal_net.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class TemporalNet(EconomicNetwork):
    """Economic network that obtains its baseline from a `LinearBaseline` submodule.

    On top of `EconomicNetwork` this class builds the baseline layer, stores the per-batch
    restart counts needed by that layer, and provides a Python-level training step.
    """

    # ExtensionType subclass declared as this network's result type.
    result_type: ClassVar[type[tf.experimental.ExtensionType]] = TemporalNetIntermediaries
    num_mask_steps: ClassVar[int] = 0

    def build(self, input_shapes):
        """Create the sub-layers this network relies on.

        Runs the parent `build` and then fetches the `LinearBaseline` submodule named
        "baseline" from `self.hyperparameters`, storing it on `self.baseline_layer`.

        Args:
            input_shapes (Tuple[tf.Tensor, ...]): Tuple of tensor shapes of positional arguments passed to `__call__()`.
        """
        super().build(input_shapes)
        # Add a new leading axis and keep only index 0 of the third axis of the wholesaler means.
        starting_sales = np.array(self.data_encodings["wholesaler_means"], dtype=np.float32)[None, :, :, 0]
        baseline_kwargs = {
            "starting_sales": starting_sales,
            "num_starts": self.data_encodings["total_restarts"],
            "encodings": self.encodings,
        }
        self.baseline_layer = self.hyperparameters.get_submodule(
            "baseline",
            module_type=LinearBaseline,
            kwargs=baseline_kwargs,
            help="A piecewise linear curve to serve as the baseline for long term predictions.",
        )

    def get_baseline(
        self, impacts: ImpactsIntermediaries, training=False, debug=False, skip_metrics=False  # noqa: U100
    ) -> tuple[tf.Tensor, dict[str, LinearBaselineIntermediaries]]:
        """Evaluate the linear baseline layer on the current placeholders.

        Args:
            impacts (ImpactsIntermediaries): Additive or multiplicative impacts on baseline.
                Not referenced by this implementation.
            training (bool, optional): Forwarded to the baseline layer. Defaults to False.
            debug (bool, optional): Forwarded to the baseline layer. Defaults to False.
            skip_metrics (bool, optional): Forwarded to the baseline layer. Defaults to False.

        Returns:
            tf.Tensor: The `baseline` attribute of the layer output.
            dict[str, LinearBaselineIntermediaries]: The full layer output stored under the
                                                    "baseline_layer" key.
        """
        layer_input = LinearBaselineInput(
            dates_since_start=self.dates_since_start_ph,
            sales_num_restarts=self.sales_num_restarts_ph,
            hierarchy=dict(self.hierarchical_placeholders.items()),
            mask=self.yhat_mask_ph,
        )
        baseline_out = self.baseline_layer(
            layer_input,
            training=training,
            skip_metrics=skip_metrics,
            debug=debug,
        )
        return baseline_out.baseline, {"baseline_layer": baseline_out}

    def create_network_phs(self, batch: EconomicModelInput, training=False):
        """Populate the network input attributes from a batch.

        Delegates to the parent implementation and additionally keeps the batch's
        `num_restarts` on `self.sales_num_restarts_ph`.

        Args:
            batch (EconomicModelInput): Batch whose fields become the network inputs.
            training (bool, optional): Forwarded to the parent implementation. Defaults to False.
        """
        super().create_network_phs(batch, training=training)
        # Read later by `get_baseline` when assembling the baseline layer input.
        self.sales_num_restarts_ph = batch.num_restarts

    def python_train_step(
        self, batch, optimizer: tf.optimizers.Optimizer, return_grads: bool = False
    ) -> tuple[dict[str, tf.Tensor], dict[str, tf.Tensor], tf.Tensor]:
        """Run one training step: forward pass, gradient computation and optimizer update.

        Args:
            batch: Input batch, passed to `self(...)` with `training=True`.
            optimizer (tf.optimizers.Optimizer): Optimizer used to apply the gradients to `self.trn_vars`.
            return_grads (bool, optional): If True, also return the gradients keyed by variable name.
                Defaults to False.

        Returns:
            tuple: `(self.targets(), losses, step)`, where `losses` is `self.get_all_losses()` merged
                with `{"loss": total_loss}` and `step` is the result of incrementing `self._step_var`.
                When `return_grads` is True a fourth element is appended: a dict mapping each variable
                name to `to_dense(grad)`, skipping variables whose gradient is None. Note that the
                return annotation only describes the three-element form.
        """
        # NOTE(review): imported at call time; presumably to avoid a circular import - confirm.
        from wt_ml.layers.layer_utils import to_dense

        # Clear this module and every submodule that is a `Module` before the forward pass.
        self.clear()
        for child in self.submodules:
            if not isinstance(child, Module):
                continue
            child.clear()

        with tf.GradientTape() as tape:
            intermediaries = self(batch, training=True)
            total_loss = self.get_total_loss()

        trainable = self.trn_vars
        grads = tape.gradient(total_loss, trainable)
        optimizer.apply_gradients(zip(grads, trainable))

        if self.baseline_layer.use_perfect_adjustment:
            if isinstance(intermediaries, Mapping):
                # The forward pass returned a mapping: pick the entry keyed by the lowercased
                # class name of the first entry in `self.NetTypes`.
                net_key = self.NetTypes[0].__name__.lower()
                intermediaries = intermediaries[net_key]
            self.baseline_layer.do_perfect_adjustment(batch, intermediaries)

        step = self._step_var.assign_add(1)
        if not return_grads:
            return (self.targets(), self.get_all_losses() | {"loss": total_loss}, step)

        gradients_tracker = {var.name: to_dense(grad) for grad, var in zip(grads, trainable) if grad is not None}
        return (self.targets(), self.get_all_losses() | {"loss": total_loss}, step, gradients_tracker)

build(input_shapes)

Build the layers needed for temporal net.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_shapes | Tuple[Tensor, ...] | Tuple of tensor shapes of positional arguments passed to `__call__()`. | required |
Source code in wt_ml/networks/temporal_net.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def build(self, input_shapes):
    """Create the sub-layers this network relies on.

    Runs the parent `build` and then fetches the `LinearBaseline` submodule named
    "baseline" from `self.hyperparameters`, storing it on `self.baseline_layer`.

    Args:
        input_shapes (Tuple[tf.Tensor, ...]): Tuple of tensor shapes of positional arguments passed to `__call__()`.
    """
    super().build(input_shapes)
    # Add a new leading axis and keep only index 0 of the third axis of the wholesaler means.
    starting_sales = np.array(self.data_encodings["wholesaler_means"], dtype=np.float32)[None, :, :, 0]
    baseline_kwargs = {
        "starting_sales": starting_sales,
        "num_starts": self.data_encodings["total_restarts"],
        "encodings": self.encodings,
    }
    self.baseline_layer = self.hyperparameters.get_submodule(
        "baseline",
        module_type=LinearBaseline,
        kwargs=baseline_kwargs,
        help="A piecewise linear curve to serve as the baseline for long term predictions.",
    )

create_network_phs(batch, training=False)

Method to convert dataset object data into tensors for network inputs. Also handles data objects passed in kwargs externally.

Source code in wt_ml/networks/temporal_net.py
74
75
76
77
78
79
80
def create_network_phs(self, batch: EconomicModelInput, training=False):
    """Populate the network input attributes from a batch.

    Delegates to the parent implementation and additionally keeps the batch's
    `num_restarts` on `self.sales_num_restarts_ph`.

    Args:
        batch (EconomicModelInput): Batch whose fields become the network inputs.
        training (bool, optional): Forwarded to the parent implementation. Defaults to False.
    """
    super().create_network_phs(batch, training=training)
    # Extra input kept by this subclass on top of whatever the parent stored.
    self.sales_num_restarts_ph = batch.num_restarts

get_baseline(impacts, training=False, debug=False, skip_metrics=False)

Call the linear baseline layer to compute the baseline.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| impacts | ImpactsIntermediaries | Additive or multiplicative impacts on baseline | required |
| training | bool | Whether training the model parameters or not. | False |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The calculated baseline. |
| dict[str, LinearBaselineIntermediaries] | Dictionary of intermediate calculations for the baseline, like slope, intercept, etc. |

Source code in wt_ml/networks/temporal_net.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def get_baseline(
    self, impacts: ImpactsIntermediaries, training=False, debug=False, skip_metrics=False  # noqa: U100
) -> tuple[tf.Tensor, dict[str, LinearBaselineIntermediaries]]:
    """Evaluate the linear baseline layer on the current placeholders.

    Args:
        impacts (ImpactsIntermediaries): Additive or multiplicative impacts on baseline.
            Not referenced by this implementation.
        training (bool, optional): Forwarded to the baseline layer. Defaults to False.
        debug (bool, optional): Forwarded to the baseline layer. Defaults to False.
        skip_metrics (bool, optional): Forwarded to the baseline layer. Defaults to False.

    Returns:
        tf.Tensor: The `baseline` attribute of the layer output.
        dict[str, LinearBaselineIntermediaries]: The full layer output stored under the
                                                "baseline_layer" key.
    """
    layer_input = LinearBaselineInput(
        dates_since_start=self.dates_since_start_ph,
        sales_num_restarts=self.sales_num_restarts_ph,
        hierarchy=dict(self.hierarchical_placeholders.items()),
        mask=self.yhat_mask_ph,
    )
    baseline_out = self.baseline_layer(
        layer_input,
        training=training,
        skip_metrics=skip_metrics,
        debug=debug,
    )
    return baseline_out.baseline, {"baseline_layer": baseline_out}