RNN实现股价预测

2023-01-31  本文已影响0人  y_7539
import pandas as pd
import numpy as np
# Load the training CSV into a DataFrame.
train_csv_path = "datas/zgpa_train.csv"
df = pd.read_csv(train_csv_path)
# Preview the first rows (the value is only displayed in a notebook cell).
df.head()
image.png
# Closing-price column of the training data.
price = df["close"]
# Normalise by the column maximum so the largest training price becomes 1.
# Series.max() is vectorised and skips NaN, whereas the builtin max() walks
# the Series element by element and is order-dependent when NaN is present.
price_norm = price / price.max()

import matplotlib.pyplot as plt
%matplotlib inline
# Plot the raw (un-normalised) training closing prices.
fig = plt.figure(figsize=(5, 3))
ax = fig.gca()
ax.plot(price)
ax.set_xlabel("time")
ax.set_ylabel("price")
plt.show()
image.png
# Extract model inputs X and targets y from a one-dimensional series.
def extract_data(data, time_step):
    """Slice a series into sliding windows and their next-step targets.

    Window ``i`` holds ``data[i : i + time_step]`` and its target is the
    value immediately after it, ``data[i + time_step]``.

    Args:
        data: One-dimensional sequence of values (list, ndarray or pandas
            Series). Elements are always read by position, so a Series with
            a non-default index is handled the same as a plain list.
        time_step: Number of consecutive values in each input window.

    Returns:
        A tuple ``(X, y)``. ``X`` is an ndarray of shape
        ``(n_samples, time_step, 1)`` and ``y`` is a list of ``n_samples``
        targets, where ``n_samples = len(data) - time_step``. When the
        series is not longer than ``time_step`` there are no samples and
        ``X`` has shape ``(0, time_step, 1)`` with ``y`` empty.
    """
    # Convert once so both the slices and the target lookups are positional.
    values = np.asarray(data)
    n_samples = len(values) - time_step
    if n_samples <= 0:
        # Too short for even one window: return correctly shaped empties
        # instead of failing while reshaping an empty array.
        return np.empty((0, time_step, 1)), []
    X = np.array([values[i:i + time_step] for i in range(n_samples)])
    y = [values[i + time_step] for i in range(n_samples)]
    # Trailing axis of size 1: one feature per time step.
    return X.reshape(n_samples, time_step, 1), y

# Window length: how many consecutive prices form one input sample.
time_step = 8

# Build X and y: the first eight prices of each window predict the ninth.
X, y = extract_data(data=price_norm, time_step=time_step)

from keras.models import Sequential
from keras.layers import Dense, SimpleRNN

# Build the model: one recurrent layer followed by a single linear output.
model = Sequential(
    [
        # Recurrent layer: 5 units reading windows of shape (time_step, 1).
        SimpleRNN(units=5, input_shape=(time_step, 1), activation="relu"),
        # Output layer: one linear unit.
        Dense(units=1, activation="linear"),
    ]
)
# Training configuration: Adam optimiser, mean-squared-error loss.
model.compile(loss="mean_squared_error", optimizer="adam")

# Train the model. Per the original note: if the loss does not change,
# the model can be re-created and trained again.
model.fit(x=X, y=np.asarray(y), epochs=200, batch_size=30)

# Predict on the training windows and scale the results back to price units.
# The training maximum is computed once: calling max(price) inside the list
# comprehension rescanned the whole Series for every element.
# float() keeps the scale factor a plain Python float, as max(price) produced.
price_max = float(price.max())
y_train_predict = model.predict(X) * price_max
y_train = [value * price_max for value in y]

# Compare predicted and true prices on the training data.
fig = plt.figure(figsize=(5, 3))
ax = fig.gca()
ax.plot(y_train_predict, label="predict price")
ax.plot(y_train, label="true price")
ax.set_xlabel("time")
ax.set_ylabel("price")
ax.legend()
plt.show()
image.png
# Load the test CSV used to evaluate the trained model.
test_csv_path = "datas/zgpa_test.csv"
test_data = pd.read_csv(test_csv_path)
# Preview the first rows (the value is only displayed in a notebook cell).
test_data.head()
image.png
# Closing-price column of the test data.
price_test = test_data["close"]
# Normalise with the TRAINING maximum (same denominator as the training
# data) so the test inputs are on the scale the model was fitted on.
# Computed once here instead of once per use and once per list element.
price_max = float(price.max())
price_test_norm = price_test / price_max
x_test_norm, y_test_norm = extract_data(price_test_norm, time_step)
# Predict on the test windows and scale everything back to price units.
y_test_predict = model.predict(x_test_norm) * price_max
y_test = [value * price_max for value in y_test_norm]

# Compare predicted and true prices on the test data.
fig = plt.figure(figsize=(5, 3))
ax = fig.gca()
ax.plot(y_test_predict, label="test predict price")
ax.plot(y_test, label="test true price")
ax.set_xlabel("time")
ax.set_ylabel("price")
ax.legend()
plt.show()
image.png
# Save the results: true and predicted test prices as two CSV columns.
result_y_test = np.asarray(y_test).reshape(-1, 1)
result_y_test_predict = y_test_predict
print(result_y_test.shape, result_y_test_predict.shape)
# Join the two column vectors side by side.
result = np.hstack((result_y_test, result_y_test_predict))
column_names = ["real_price_test", "predict_price_test"]
result = pd.DataFrame(result, columns=column_names)
result.to_csv("zgpa_predict_test.csv")
# Author's observation: the predicted curve lags the true one by one step.
上一篇下一篇

猜你喜欢

热点阅读