class EconomicModelInput(AnnotatedExtensionTypeWithShape, tf.experimental.ExtensionType):
    """
    The input class used to prepare batches of data.

    Each field is declared as ``Annotated[<tensor type>, TensorMetadata(<axes>, <dtype>)]``.
    Only ``date_index`` and the four categorical hierarchy indices (``state_index``,
    ``wholesaler_index``, ``brand_index``, ``granularity_index``) have no default; every other
    tensor field defaults to ``None`` and ``continuous_hier_params`` defaults to an empty mapping.

    NOTE(review): the ``TensorMetadata`` arguments look like a tuple of named axes followed by a
    NumPy dtype, presumably consumed by ``AnnotatedExtensionTypeWithShape``. How (and whether)
    they are validated is not visible in this class - confirm against that base class.

    Attributes:
        no_prediction_mask (tf.Tensor): Mask for places we don't want to predict or train on.
        no_train_mask (tf.Tensor): Mask for places we don't want to train on. We also won't train
            anywhere we don't do predictions.
        feature_masks (tf.Tensor): Stacked masks for places where we don't want to train, but think
            that due to unforeseen externalities we should perfectly predict. Each mask will
            attribute the prediction error to a different driver.
            NOTE(review): "should perfectly predict" may have been meant as "should not expect to
            perfectly predict" - confirm with the author.
    """
    # NOTE: declaration order is part of the interface. tf.experimental.ExtensionType generates
    # the constructor from the fields in the order they are listed, so the fields without
    # defaults are kept first and existing fields should not be reordered.

    # Required (without massive refactor) axis indices
    date_index: Annotated[tf.Tensor, TensorMetadata((Time,), np.int32)]
    # Required (without massive refactor) categorical hierarchy params, one entry per Batch row
    state_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    wholesaler_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    brand_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    granularity_index: Annotated[tf.Tensor, TensorMetadata((Batch,), np.int32)]
    # continuous hierarchy params, keyed by name
    # NOTE(review): `{}` is a literal default. ExtensionType instances are documented as immutable,
    # so this is presumably not a shared-mutable-default hazard - confirm before changing.
    continuous_hier_params: Mapping[str, Annotated[tf.Tensor, TensorMetadata((Batch,), np.float32)]] = {}
    # target
    true_sales: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    # TODO(@ruler501): This could be replaced with a property that evaluates to true_sales / (price + EPSILON) right?
    true_volume: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    # axis indices: one int32 vector per named feature axis
    vehicle_index: Annotated[Optional[tf.Tensor], TensorMetadata((Vehicle,), np.int32)] = None
    global_index: Annotated[Optional[tf.Tensor], TensorMetadata((GlobalSignal,), np.int32)] = None
    weather_index: Annotated[Optional[tf.Tensor], TensorMetadata((WeatherSignal,), np.int32)] = None
    temperature_index: Annotated[Optional[tf.Tensor], TensorMetadata((TemperatureSignal,), np.int32)] = None
    holiday_index: Annotated[Optional[tf.Tensor], TensorMetadata((HolidaySignal,), np.int32)] = None
    price_dev_index: Annotated[Optional[tf.Tensor], TensorMetadata((PriceDev,), np.int32)] = None
    price_ratio_index: Annotated[Optional[tf.Tensor], TensorMetadata((PriceRatio,), np.int32)] = None
    distribution_index: Annotated[Optional[tf.Tensor], TensorMetadata((Distribution,), np.int32)] = None
    national_trend_index: Annotated[Optional[tf.Tensor], TensorMetadata((NationalTrend,), np.int32)] = None
    regional_trend_index: Annotated[Optional[tf.Tensor], TensorMetadata((RegionalTrend,), np.int32)] = None
    # categorical hierarchy indices (optional counterparts of the required ones above);
    # the *_parent_index / full_vehicle_index fields are laid out along their feature axis
    product_index: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.int32)] = None
    region_index: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.int32)] = None
    full_vehicle_index: Annotated[Optional[tf.Tensor], TensorMetadata((Vehicle,), np.int32)] = None
    parent_vehicle_index: Annotated[Optional[tf.Tensor], TensorMetadata((Vehicle,), np.int32)] = None
    global_parent_index: Annotated[Optional[tf.Tensor], TensorMetadata((GlobalSignal,), np.int32)] = None
    weather_parent_index: Annotated[Optional[tf.Tensor], TensorMetadata((WeatherSignal,), np.int32)] = None
    temperature_parent_index: Annotated[Optional[tf.Tensor], TensorMetadata((TemperatureSignal,), np.int32)] = None
    feature_mask_index: Annotated[Optional[tf.Tensor], TensorMetadata((FeatureMask,), np.int32)] = None
    # masks and weights (see the class docstring for the meaning of the first three masks)
    no_prediction_mask: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.bool_)] = None
    no_train_mask: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.bool_)] = None
    feature_masks: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, FeatureMask), np.bool_)] = None
    # Unlike the masks above, this one is declared over Time only (no Batch axis).
    before_2021_mask: Annotated[Optional[tf.Tensor], TensorMetadata((Time,), np.bool_)] = None
    instability_loss_mult: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    # normalization_factors
    price_normalization: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    distribution_means: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Distribution), np.float32)] = None
    # features: declared over (Batch, Time), most with a trailing feature axis
    distributions: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, Distribution), np.float32)] = None
    price: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    imputed_price: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    price_devs: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, PriceDev), np.float32)] = None
    vehicle_spends: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, Vehicle), np.float32)] = None
    global_signals: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, GlobalSignal), np.float32)] = None
    weather_signals: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, WeatherSignal), np.float32)] = None
    temperature_signals: Annotated[
        Optional[tf.Tensor], TensorMetadata((Batch, Time, TemperatureSignal), np.float32)
    ] = None
    holiday_signals: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, HolidaySignal), np.float32)] = None
    national_trend: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, NationalTrend), np.float32)] = None
    regional_trend: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, RegionalTrend), np.float32)] = None
    yearly_week_number: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.int32)] = None
    price_ratios: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time, PriceRatio), np.float32)] = None
    # Non-hierarchical metadata
    brand_size: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    investment_axis_scale: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Vehicle), np.float32)] = None
    num_restarts: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.int32)] = None
    weeks_since_restart: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    maco_cost: Annotated[Optional[tf.Tensor], TensorMetadata((Batch, Time), np.float32)] = None
    preinvestment_slope: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None
    preinvestment_intercept: Annotated[Optional[tf.Tensor], TensorMetadata((Batch,), np.float32)] = None