您的位置:首页 > 财经 > 产业 > 下面不属于网络推广方法_广告网页制作模板_产品软文怎么写_网站页面优化内容包括哪些

下面不属于网络推广方法_广告网页制作模板_产品软文怎么写_网站页面优化内容包括哪些

2025/2/22 14:03:27 来源:https://blog.csdn.net/yweng18/article/details/145717688  浏览:    关键词:下面不属于网络推广方法_广告网页制作模板_产品软文怎么写_网站页面优化内容包括哪些
下面不属于网络推广方法_广告网页制作模板_产品软文怎么写_网站页面优化内容包括哪些

第11集:时间序列预测——循环神经网络(RNN)与 LSTM

在机器学习中,时间序列预测 是一种重要的任务,广泛应用于股票价格预测、天气预报、语音识别等领域。传统的机器学习方法难以捕捉时间序列中的长期依赖关系,而 循环神经网络(Recurrent Neural Networks, RNN) 和其改进版本 长短期记忆网络(LSTM, Long Short-Term Memory) 能够有效解决这一问题。今天我们将深入探讨 RNN 和 LSTM 的原理,并通过实践部分使用 LSTM 对股票价格数据进行预测


时间序列数据的特点

什么是时间序列?

时间序列数据是一组按时间顺序排列的观测值,具有以下特点:

  1. 时序性:数据点之间存在时间依赖关系。
  2. 趋势性:数据可能表现出长期上升或下降的趋势。
  3. 周期性:数据可能具有重复的周期模式(如季节性变化)。
  4. 噪声:数据中可能存在随机波动。

图1:时间序列数据示例
(图片描述:折线图展示了股票价格随时间的变化,包含上升趋势、周期波动和随机噪声。)
在这里插入图片描述


RNN 的基本结构与局限性

RNN 的核心思想

RNN 是一种专门用于处理序列数据的神经网络,其核心思想是引入循环结构,使得当前时刻的输出不仅依赖于当前输入,还依赖于之前的状态。公式如下:
h t = f ( W h h t − 1 + W x x t + b ) h_t = f(W_h h_{t-1} + W_x x_t + b) ht=f(Whht1+Wxxt+b)
其中:

  • h t 是当前时刻的隐藏状态。 h_t 是当前时刻的隐藏状态。 ht是当前时刻的隐藏状态。
  • h t − 1 是上一时刻的隐藏状态。 h_{t-1} 是上一时刻的隐藏状态。 ht1是上一时刻的隐藏状态。
  • x t 是当前时刻的输入。 x_t 是当前时刻的输入。 xt是当前时刻的输入。
  • f 是激活函数(通常为 t a n h 或 R e L U )。 f 是激活函数(通常为 tanh 或 ReLU)。 f是激活函数(通常为tanhReLU)。

图2:RNN 结构示意图
(图片描述:一个简单的 RNN 模型,展示了输入、隐藏状态和输出之间的循环连接。)
在这里插入图片描述

RNN 的局限性

尽管 RNN 能够捕捉短期依赖关系,但在处理长期依赖时容易出现梯度消失或梯度爆炸问题,导致模型性能下降。


LSTM 与 GRU 如何解决长期依赖问题

1. LSTM(Long Short-Term Memory)

LSTM 是 RNN 的改进版本,通过引入门控机制(遗忘门、输入门和输出门),能够更好地捕捉长期依赖关系。公式如下:
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) C t = f t ⋅ C t − 1 + i t ⋅ C ~ t o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) h t = o t ⋅ tanh ⁡ ( C t ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \\ C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}_t \\ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\ h_t = o_t \cdot \tanh(C_t) ft=σ(Wf[ht1,xt]+bf)it=σ(Wi[ht1,xt]+bi)C~t=tanh(WC[ht1,xt]+bC)Ct=ftCt1+itC~tot=σ(Wo[ht1,xt]+bo)ht=ottanh(Ct)
其中:

  • f t 是遗忘门,控制哪些信息需要被遗忘。 f_t 是遗忘门,控制哪些信息需要被遗忘。 ft是遗忘门,控制哪些信息需要被遗忘。
  • i t 是输入门,控制哪些新信息需要被添加。 i_t 是输入门,控制哪些新信息需要被添加。 it是输入门,控制哪些新信息需要被添加。
  • o t 是输出门,控制哪些信息需要被输出。 o_t 是输出门,控制哪些信息需要被输出。 ot是输出门,控制哪些信息需要被输出。

图3:LSTM 单元结构
(图片描述:LSTM 单元内部结构图,展示了遗忘门、输入门、输出门和细胞状态的交互过程。)
在这里插入图片描述

2. GRU(Gated Recurrent Unit)

GRU 是 LSTM 的简化版本,通过合并遗忘门和输入门为更新门,减少了参数数量,同时保留了捕捉长期依赖的能力。


时间序列预测的常见应用场景

  1. 金融领域:股票价格预测、汇率预测。
  2. 气象领域:天气预报、气候变化分析。
  3. 工业领域:设备故障预测、生产计划优化。
  4. 医疗领域:疾病发展趋势预测、患者健康监测。

实践部分:使用 LSTM 对股票价格数据进行预测

数据集简介

我们使用一个公开的股票价格数据集,包含某公司每日的开盘价、收盘价、最高价、最低价和成交量。目标是基于历史数据预测未来一天的收盘价。

完整代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.optimizers import Adam# 加载数据
url = "https://raw.githubusercontent.com/mwitiderrick/stockprice/master/NSE-TATAGLOBAL.csv"
data = pd.read_csv(url)
data['Date'] = pd.to_datetime(data['Date'])
data = data.sort_values('Date')
prices = data['Close'].values.reshape(-1, 1)# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_prices = scaler.fit_transform(prices)# 创建训练集和测试集
def create_dataset(dataset, time_step=60):X, y = [], []for i in range(len(dataset) - time_step - 1):X.append(dataset[i:i + time_step, 0])y.append(dataset[i + time_step, 0])return np.array(X), np.array(y)time_step = 60
X, y = create_dataset(scaled_prices, time_step)
X = X.reshape(X.shape[0], X.shape[1], 1)train_size = int(len(X) * 0.8)
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]# 构建 LSTM 模型
model = Sequential([LSTM(50, return_sequences=True, input_shape=(time_step, 1)),LSTM(50, return_sequences=False),Dense(25),Dense(1)
])# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001), loss='mean_squared_error')# 训练模型
history = model.fit(X_train, y_train, epochs=20, batch_size=64, validation_data=(X_test, y_test))# 预测
predictions = model.predict(X_test)
predictions = scaler.inverse_transform(predictions)
y_test_actual = scaler.inverse_transform(y_test.reshape(-1, 1))# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(y_test_actual, label='True Prices', color='blue')
plt.plot(predictions, label='Predicted Prices', color='red')
plt.title('Stock Price Prediction using LSTM', fontsize=16)
plt.xlabel('Time', fontsize=12)
plt.ylabel('Price', fontsize=12)
plt.legend()
plt.grid()
plt.show()# 输出损失曲线
plt.figure(figsize=(8, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss Curve', fontsize=16)
plt.xlabel('Epochs', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.legend()
plt.grid()
plt.show()

运行结果

2025-02-21 01:30:45.490392: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-21 01:30:48.638446: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-21 01:30:56.949151: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
D:\python_projects\music_player\Lib\site-packages\keras\src\layers\rnn\rnn.py:200: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.super().__init__(**kwargs)
Epoch 1/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 3s 42ms/step - loss: 0.0218 - val_loss: 0.0360
Epoch 2/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 0.0016 - val_loss: 0.0037
Epoch 3/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 27ms/step - loss: 6.3420e-04 - val_loss: 0.0022
Epoch 4/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 5.9082e-04 - val_loss: 0.0031
Epoch 5/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 5.3243e-04 - val_loss: 0.0024
Epoch 6/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 34ms/step - loss: 5.2398e-04 - val_loss: 0.0022
Epoch 7/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 30ms/step - loss: 5.4002e-04 - val_loss: 0.0024
Epoch 8/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 4.6474e-04 - val_loss: 0.0020
Epoch 9/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 4.4463e-04 - val_loss: 0.0028
Epoch 10/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - loss: 4.5939e-04 - val_loss: 0.0019
Epoch 11/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 4.4314e-04 - val_loss: 0.0019
Epoch 12/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 30ms/step - loss: 4.3458e-04 - val_loss: 0.0026
Epoch 13/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 4.2457e-04 - val_loss: 0.0024
Epoch 14/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 4.0447e-04 - val_loss: 0.0034
Epoch 15/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 27ms/step - loss: 3.8518e-04 - val_loss: 0.0027
Epoch 16/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 3.7076e-04 - val_loss: 0.0024
Epoch 17/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 3.7536e-04 - val_loss: 0.0018
Epoch 18/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 3.7514e-04 - val_loss: 0.0020
Epoch 19/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 3.7502e-04 - val_loss: 0.0041
Epoch 20/20
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 3.6098e-04 - val_loss: 0.0025
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 25ms/step 
预测结果可视化

图4:真实价格与预测价格对比
(图片描述:折线图展示了真实股票价格(蓝色)与 LSTM 模型预测价格(红色)的对比,两条曲线较为接近,表明模型表现良好。)
在这里插入图片描述

损失曲线

图5:训练与验证损失曲线
(图片描述:折线图展示了训练和验证损失随 epoch 的变化,两条曲线均逐渐下降并趋于平稳。)
在这里插入图片描述


总结

本文介绍了时间序列数据的特点、RNN 的基本结构及其局限性,并详细讲解了 LSTM 如何解决长期依赖问题。通过实践部分,我们成功使用 LSTM 对股票价格数据进行了预测。希望这篇文章能帮助你更好地理解时间序列预测的基本原理!

下集预告:机器学习实战(12):项目实战——端到端的机器学习项目


参考资料

  • TensorFlow 文档: https://www.tensorflow.org/
  • 股票价格数据集: https://github.com/mwitiderrick/stockprice

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com