时间嵌入选择代码解析
这行代码根据embed_type
参数动态选择使用哪种时间嵌入方法,是一个条件赋值语句(三元表达式):
self.temporal_embedding = TemporalEmbedding(...) if embed_type != 'timeF' else TimeFeatureEmbedding(...)
解析
-
条件判断:
embed_type != 'timeF'
- 检查嵌入类型是否不等于’timeF’
-
如果条件为真(embed_type不是’timeF’):
- 使用
TemporalEmbedding
类,这是一种基于查找表的离散时间特征嵌入 - 适用于将时间信息(月份、日期、星期几等)作为离散类别处理
- 使用
-
如果条件为假(embed_type是’timeF’):
- 使用
TimeFeatureEmbedding
类,这是一种基于线性映射的连续时间特征嵌入 - 适用于将时间特征作为连续数值处理
- 使用
两种嵌入方法的区别
-
TemporalEmbedding:
- 使用查找表(Embedding表)存储每个时间值的嵌入向量
- 将时间视为离散类别(如第1月、第2月…)
- 输出是各时间特征嵌入的和
-
TimeFeatureEmbedding:
- 使用线性层(nn.Linear)直接映射时间特征
- 将时间视为连续特征
- 输入是归一化的时间特征,如一年中的相对位置
这种设计允许模型根据不同的数据特性和实验需求灵活选择最适合的时间表示方法。
TimeFeatureEmbedding 类解析
TimeFeatureEmbedding
类是一个专为连续时间特征设计的嵌入层,与离散类别的 TemporalEmbedding
不同。
形状和操作分析
初始化
def __init__(self, d_model, embed_type='timeF', freq='h'):# freq_map定义不同数据频率下使用的时间特征维度freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}# 根据频率确定输入维度d_inp = freq_map[freq]# 创建无偏置的线性层将d_inp维度映射到d_model维度self.embed = nn.Linear(d_inp, d_model, bias=False)
前向传播
def forward(self, x):# 输入x形状: [B, L, d_inp] - B是批次大小, L是序列长度, d_inp是时间特征数量# 线性变换后输出形状: [B, L, d_model]return self.embed(x)
关键特点
-
连续特征映射:
- 直接对时间特征进行线性变换,而不是像
TemporalEmbedding
那样进行查表操作 - 适用于连续的、已归一化的时间特征
- 直接对时间特征进行线性变换,而不是像
-
频率相关输入维度:
- 根据不同的时间序列频率(
freq
)确定输入维度 - 例如,小时级数据使用4个特征,分钟级使用5个特征
- 根据不同的时间序列频率(
-
形状转换:
- 输入:
[B, L, d_inp]
- 线性映射:
W·x
其中W
是形状为[d_inp, d_model]
的权重矩阵 - 输出:
[B, L, d_model]
- 输入:
这种设计使模型可以直接处理连续的时间特征编码,比如周期性的正弦/余弦表示,而不需要将时间离散化为类别。