EconomicModelInput

Bases: AnnotatedExtensionTypeWithShape, ExtensionType

The input class used to prepare batches of data.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| no_prediction_mask | Tensor | Mask for places we don't want to predict or train on. |
| no_train_mask | Tensor | Mask for places we don't want to train on. We also won't train anywhere we don't do predictions. |
| feature_masks | Tensor | Stacked masks for places where we don't want to train, but think that due to unforeseen externalities we should perfectly predict. Each mask will attribute the prediction error to a different driver. |
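
The mask descriptions above imply that a cell is excluded from training whenever either `no_train_mask` or `no_prediction_mask` is set. The following is a minimal NumPy sketch of that rule, not the wt_ml implementation. It assumes that `True` marks an excluded cell and that the loss is a squared error; the helper names `effective_train_exclusion` and `masked_mse` are hypothetical.

```python
import numpy as np


def effective_train_exclusion(no_prediction_mask, no_train_mask):
    """Cells excluded from training: no-train cells plus every no-prediction cell.

    Assumption: True means "exclude this (Batch, Time) cell".
    """
    return np.logical_or(no_prediction_mask, no_train_mask)


def masked_mse(pred, target, exclude):
    """Mean squared error over the cells that are NOT excluded (hypothetical helper)."""
    keep = ~exclude
    n_kept = keep.sum()
    if n_kept == 0:
        # Nothing to train on; return a neutral loss instead of dividing by zero.
        return 0.0
    sq_err = (pred - target) ** 2
    return float(sq_err[keep].sum() / n_kept)


# Toy example with Batch=2 and Time=3.
no_prediction_mask = np.array([[True, False, False], [False, False, False]])
no_train_mask = np.array([[False, True, False], [False, False, True]])
exclude = effective_train_exclusion(no_prediction_mask, no_train_mask)

pred = np.zeros((2, 3), dtype=np.float32)
target = np.array([[9.0, 9.0, 1.0], [2.0, 1.0, 9.0]], dtype=np.float32)
loss = masked_mse(pred, target, exclude)
```

Only the three cells left unmasked by both masks contribute to `loss`; the large errors in the masked cells are ignored.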

Source code in wt_ml/dataset/data_pipeline.py
class EconomicModelInput(AnnotatedExtensionTypeWithShape, tf.experimental.ExtensionType):
    """
    The input class used to prepare batches of data.

    Attributes:
        no_prediction_mask (tf.Tensor): Mask for places we don't want to predict or train on.
        no_train_mask (tf.Tensor): Mask for places we don't want to train on. We also won't train
                anywhere we don't do predictions.
        feature_masks (tf.Tensor): Stacked masks for places where we don't want to train, but think
                that due to unforeseen externalities we should perfectly predict. Each mask will
                attribute the prediction error to a different driver.
    """

    # Required (without massive refactor) axis indices
    date_index: Annotated[tf.Tensor, TensorMetadata((Time,), np.int32)]
    # Required (without massive refactor) categorical hierarchy params
    state_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    wholesaler_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    brand_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    granularity_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    # continuous hierarchy params
    continuous_hier_params: Mapping[str, Annotated[tf.Tensor, TensorMetadata((Batch,), np.float32)]] = {}
    # target
    true_sales: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    # TODO(@ruler501): This could be replaced with a property that evaluates to true_sales / (price + EPSILON) right?
    true_volume: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    # axis indices
    vehicle_index: Annotated[Optional[tf.Tensor], TensorMetadata((Vehicle,), np.int32)] = None
    global_index: Annotated[Optional[tf.Tensor], TensorMetadata((GlobalSignal,), np.int32)] = None
    weather_index: Annotated[Optional[tf.Tensor], TensorMetadata((WeatherSignal,), np.int32)] = None
    temperature_index: Annotated[Optional[tf.Tensor], TensorMetadata((TemperatureSignal,), np.int32)] = None
    holiday_index: Annotated[Optional[tf.Tensor], TensorMetadata((HolidaySignal,), np.int32)] = None
    price_dev_index: Annotated[Optional[tf.Tensor], TensorMetadata((PriceDev,), np.int32)] = None
    price_ratio_index: Annotated[Optional[tf.Tensor], TensorMetadata((PriceRatio,), np.int32)] = None
    distribution_index: Annotated[Optional[tf.Tensor], TensorMetadata((Distribution,), np.int32)] = None
    national_trend_index: Annotated[Optional[tf.Tensor], TensorMetadata((NationalTrend,), np.int32)] = None
    regional_trend_index: Annotated[Optional[tf.Tensor], TensorMetadata((RegionalTrend,), np.int32)] = None
    # categorical hierarchy indices
    product_index: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.int32)] = None
    region_index: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.int32)] = None
    full_vehicle_index: Annotated[Optional[tf.Tensor], TensorMetadata((Vehicle,), np.int32)] = None
    parent_vehicle_index: Annotated[Optional[tf.Tensor], TensorMetadata((Vehicle,), np.int32)] = None
    global_parent_index: Annotated[Optional[tf.Tensor], TensorMetadata((GlobalSignal,), np.int32)] = None
    weather_parent_index: Annotated[Optional[tf.Tensor], TensorMetadata((WeatherSignal,), np.int32)] = None
    temperature_parent_index: Annotated[Optional[tf.Tensor], TensorMetadata((TemperatureSignal,), np.int32)] = None
    feature_mask_index: Annotated[Optional[tf.Tensor], TensorMetadata((FeatureMask,), np.int32)] = None
    # masks and weights
    no_prediction_mask: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.bool_)] = None
    no_train_mask: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.bool_)] = None
    feature_masks: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, FeatureMask), np.bool_)] = None
    before_2021_mask: Annotated[Optional[tf.Tensor], TensorMetadata((Time,), np.bool_)] = None
    instability_loss_mult: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    # normalization_factors
    price_normalization: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    distribution_means: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Distribution), np.float32)] = None
    # features
    distributions: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, Distribution), np.float32)] = None
    price: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    imputed_price: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    price_devs: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, PriceDev), np.float32)] = None
    vehicle_spends: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, Vehicle), np.float32)] = None
    global_signals: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, GlobalSignal), np.float32)] = None
    weather_signals: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, WeatherSignal), np.float32)] = None
    temperature_signals: Annotated[
        Optional[tf.Tensor], TensorMetadata((Batch, Time, TemperatureSignal), np.float32)
    ] = None
    holiday_signals: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, HolidaySignal), np.float32)] = None
    national_trend: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, NationalTrend), np.float32)] = None
    regional_trend: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, RegionalTrend), np.float32)] = None
    yearly_week_number: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.int32)] = None
    price_ratios: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, PriceRatio), np.float32)] = None
    # Non-hierarchical metadata
    brand_size: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    investment_axis_scale: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Vehicle), np.float32)] = None
    num_restarts: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.int32)] = None
    weeks_since_restart: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    maco_cost: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    preinvestment_slope: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    preinvestment_intercept: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
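
Each field above attaches its axes and dtype through `typing.Annotated`. As an illustration of how such metadata can be read back at runtime, here is a standard-library sketch. The `TensorMetadata` dataclass, the `ToyInput` class, and the `collect_metadata` helper are hypothetical stand-ins written for this example (axes are plain strings and the field types are lists); they are not the wt_ml definitions and make no claim about how `AnnotatedExtensionTypeWithShape` works internally.

```python
from dataclasses import dataclass
from typing import Annotated, Optional, get_args, get_origin, get_type_hints


@dataclass(frozen=True)
class TensorMetadata:
    """Hypothetical stand-in: axis names and a dtype name for one field."""

    axes: tuple
    dtype: str


class ToyInput:
    """Toy class mirroring three of the annotated fields above."""

    date_index: Annotated[list, TensorMetadata(("Time",), "int32")]
    no_train_mask: Annotated[Optional[list], TensorMetadata(("Batch", "Time"), "bool")] = None
    feature_masks: Annotated[
        Optional[list], TensorMetadata(("Batch", "Time", "FeatureMask"), "bool")
    ] = None


def collect_metadata(cls):
    """Map each Annotated field name to the TensorMetadata attached to it."""
    found = {}
    # include_extras=True keeps the Annotated wrapper instead of stripping it.
    for name, hint in get_type_hints(cls, include_extras=True).items():
        if get_origin(hint) is Annotated:
            # get_args returns (underlying_type, *metadata); skip the type.
            for extra in get_args(hint)[1:]:
                if isinstance(extra, TensorMetadata):
                    found[name] = extra
    return found


metadata = collect_metadata(ToyInput)
# Expected rank of each field, derived from the number of declared axes.
ranks = {name: len(meta.axes) for name, meta in metadata.items()}
```

Keeping shape and dtype next to the field declaration lets a base class validate or document every field from one place, which is presumably why the real class annotates all of its tensors this way.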