def create_lead_lag_signal(time_signal: tf.Tensor, n_leads: int, n_lags: int) -> tf.Tensor:
    """Generate lead/lag shifted copies of `time_signal` along the time axis.

    A lag of ``k`` delays the signal by ``k`` steps (zeros are shifted in at the
    start); a lead of ``k`` advances it by ``k`` steps (zeros are shifted in at
    the end).

    Args:
        time_signal (tf.Tensor): Input time signal tensor with shape
            (batch_size, time_steps, num_signals). Time is axis 1.
        n_leads (int): Number of lead periods to create. If n_leads <= 0, no lead
            signals are created.
        n_lags (int): Number of lag periods to create. If n_lags <= 0, no lag
            signals are created.

    Returns:
        tf.Tensor: Output tensor with shape
        (batch_size, time_steps, num_signals, max(n_lags, 0) + 1 + max(n_leads, 0)).
        The last axis holds the lag signals (largest lag first), followed by the
        original time signal, followed by the lead signals (smallest lead first):
        [signal_-n_lags, ..., signal_-1, signal, signal_+1, ..., signal_+n_leads].

    Example:
        batch_size = 2
        time_steps = 10
        num_signals = 3
        n_leads = 2
        n_lags = 1
        time_signal = tf.constant(np.random.rand(batch_size, time_steps, num_signals), dtype=tf.float32)
        output_signal = create_lead_lag_signal(time_signal, n_leads, n_lags)
        assert output_signal.shape == (batch_size, time_steps, num_signals, (n_lags + 1 + n_leads))
    """
    # Non-positive counts mean "none". Clamp them so a negative value never
    # reaches tf.pad as a negative padding amount, which would raise.
    n_leads = max(n_leads, 0)
    n_lags = max(n_lags, 0)
    if n_leads == 0 and n_lags == 0:
        return tf.expand_dims(time_signal, -1)

    num_time = tf.shape(time_signal)[1]

    # Zero-pad n_lags steps before and n_leads steps after the signal along time.
    # Slice ends are exclusive, so the largest lead window
    # [n_lags + n_leads, n_lags + n_leads + num_time) ends exactly at the padded
    # length; no extra padding is required.
    # e.g. signal [1,2,3,4], n_lags=2, n_leads=3 -> padded [0,0,1,2,3,4,0,0,0]
    padded_signal = tf.pad(time_signal, [[0, 0], [n_lags, n_leads], [0, 0]], name="padded_signal")

    # The original signal occupies padded[n_lags : n_lags + num_time].
    # A lag of k takes the same-length window starting k steps earlier,
    # e.g. k=2 -> [0,0,1,2]. Ordered from the largest lag down to 1.
    lag_signals = [
        padded_signal[:, n_lags - lag : n_lags - lag + num_time]
        for lag in range(n_lags, 0, -1)
    ]

    # A lead of k takes the same-length window starting k steps later,
    # e.g. k=2 -> [3,4,0,0]. Ordered from 1 up to n_leads.
    lead_signals = [
        padded_signal[:, n_lags + lead : n_lags + lead + num_time]
        for lead in range(1, n_leads + 1)
    ]

    combined_signals = lag_signals + [time_signal] + lead_signals
    # Stack lag, original, and lead signals along a new last axis.
    return tf.stack(combined_signals, axis=-1, name="combined_lead_lag_output")