Data Analysis Case Studies: DeepAR

Updated DeepAR Code for Stock Forecasting

2020-07-22  白马闯红灯

This notebook uses DeepAR from gluon-ts, which is published by Amazon. The data comes from www.baostock.com. The code is adapted from https://github.com/samemelody/gluon-ts/blob/master/examples/COV19-forecast.ipynb NOTE: this notebook is for illustration purposes only.

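My earlier stock forecast did not work very well, and I never tested it. This time I forecast from data that stops 20 days before the end of the series, and the result was surprisingly good. The data again comes from baostock.com, which is a really handy site. The code follows the COVID-19 forecasting example written by awslabs, and I basically copied it from there. DeepAR does quite well at forecasting this kind of multi-item data. The source code is on GitHub if you want to try it. I ran everything in JupyterLab.
I have to help someone with six months of data analysis at the moment. When I have time I will tune this model once a day and post the predictor results here. I won't paste the code, just the result images. I have only just started playing with machine learning, and it is a lot of fun.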

import pprint
from functools import partial
import baostock as bs
import pandas as pd

from gluonts.dataset.common import ListDataset
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.distribution.piecewise_linear import PiecewiseLinearOutput
from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer
from pathlib import Path
from gluonts.model.predictor import Predictor
from gluonts.distribution.neg_binomial import NegativeBinomialOutput
import matplotlib.pyplot as plt
import json
#tqdm.autonotebook.tqdm
from tqdm.autonotebook import tqdm

Get the data from baostock.com, with volume and turn (turnover rate) as dynamic_feat.

def mygetstockdata(code):
    # Requires an active session: call bs.login() before using this function.
    rs = bs.query_history_k_data_plus(
        code,
        "date,close,volume,turn",
        start_date="2018-01-01",
        # frequency="d" requests daily K-line bars; adjustflag="2" is
        # forward-adjusted prices (the default, "3", is unadjusted).
        frequency="d",
        adjustflag="2",
    )
    return rs.get_data()
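As far as I know, baostock returns every field as a string, and some fields may be blank on certain days (that second part is an assumption I have not verified for every column). Here is a minimal sketch of turning such a column into floats before handing it to a model. The helper name `to_floats` and the sample values are made up for illustration.

```python
import math


def to_floats(values, missing=math.nan):
    """Convert string fields to floats; blank or malformed entries become `missing`."""
    result = []
    for raw in values:
        try:
            result.append(float(raw))
        except (TypeError, ValueError):
            # float("") raises ValueError, float(None) raises TypeError
            result.append(missing)
    return result


# Hypothetical samples shaped like the "close" and "volume" columns.
close = to_floats(["10.52", "10.61", "10.48"])
volume = to_floats(["120300", "", "98000"], missing=0.0)
```

Whether a blank should become 0.0, NaN, or the previous day's value depends on the column, so treat the `missing` default as a placeholder choice.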


Randomly choose 10 stock codes to test DeepAR.

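liststock = ['sh.000001', 'sz.399001', 'sz.002068', 'sz.300324', 'sh.600754',
             'sz.300439', 'sh.600862', 'sh.603991', 'sh.600081', 'sz.300026']
prediction_length = 20
#liststock = ['sz.300462','sz.300789']
listtrandic = []
listtestdic = []
lg = bs.login()
for ite in liststock:
    dd = mygetstockdata(ite)
    # baostock returns strings, so convert the prices to floats first.
    close = dd.close.astype(float).tolist()
    volume, turn = list(dd.volume), list(dd.turn)
    cat = int(ite.split('.')[1])
    # Note: "cat" and "dynamic_feat" follow the SageMaker DeepAR naming. GluonTS reads
    # feat_static_cat / feat_dynamic_real, and DeepAREstimator only uses them when
    # use_feat_static_cat / use_feat_dynamic_real are set, so they are likely ignored here.
    # Training series: hold back the last prediction_length days so the model never sees them.
    trandic = {"start": dd.date[0],
               "target": close[:-prediction_length],
               "cat": cat,
               "dynamic_feat": [volume[:-prediction_length], turn[:-prediction_length]]}
    # Test series: full history. make_evaluation_predictions removes the last
    # prediction_length points itself and forecasts them.
    testdic = {"start": dd.date[0],
               "target": close,
               "cat": cat,
               "dynamic_feat": [volume, turn]}
    listtrandic.append(trandic)
    listtestdic.append(testdic)
bs.logout()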

traindata = ListDataset(
    listtrandic,
    freq = "1d"
)

testdata = ListDataset(
    listtestdic,
    freq = "1d"
)
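The hold-out idea behind the two datasets can be sketched in plain Python. In a typical GluonTS backtest the training series stops prediction_length steps before the end and the test series keeps the full history, because make_evaluation_predictions cuts the last prediction_length points off each test series on its own before forecasting. The helper `split_holdout` and the stand-in series are hypothetical, just to show the slicing.

```python
def split_holdout(target, prediction_length):
    """Return (train, test): train drops the last `prediction_length` points, test keeps all."""
    if prediction_length <= 0 or prediction_length >= len(target):
        raise ValueError("prediction_length must be between 1 and len(target) - 1")
    return target[:-prediction_length], target


series = list(range(100))  # stand-in for a list of closing prices
train_target, test_target = split_holdout(series, prediction_length=20)

# The points the model never sees during training and is later asked to forecast.
held_out = test_target[len(train_target):]
```

If the training series also contained those last 20 points, the forecast would be scored on data the model had already seen, which makes the results look better than they are.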

Training

estimator = DeepAREstimator(
    prediction_length=prediction_length,
    context_length=60,
    freq="1d",
    trainer=Trainer(ctx="cpu",
                    epochs=100,#30
                    learning_rate=1e-2,
                    num_batches_per_epoch=300, #100
                   )
)
predictor = estimator.train(traindata)

predictor

from gluonts.evaluation.backtest import make_evaluation_predictions
from tqdm.autonotebook import tqdm
forecast_it, ts_it = make_evaluation_predictions(
    dataset=testdata,
    predictor=predictor,
    num_samples=100
)

print("Obtaining time series conditioning values ...")
tss = list(tqdm(ts_it, total=len(testdata)))
print("Obtaining time series predictions ...")
forecasts = list(tqdm(forecast_it, total=len(testdata)))
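To put a number on "how good" the forecasts are, one option is to compare the sample paths with the actual values. The Evaluator imported at the top is the GluonTS way of doing this. The sketch below is only a plain-Python illustration of the idea, assuming each forecast is a list of sample paths (GluonTS sample forecasts expose these as a `samples` array). The function names and the toy numbers are made up.

```python
def quantile(sorted_values, q):
    """Linear-interpolation quantile of an already sorted list, with 0 <= q <= 1."""
    position = q * (len(sorted_values) - 1)
    lower = int(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    weight = position - lower
    return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight


def summarize(samples, actual, low=0.1, high=0.9):
    """Return (median forecast, MAPE of the median, share of actuals inside [low, high])."""
    horizon = len(actual)
    medians, covered = [], 0
    for t in range(horizon):
        # All sampled values for forecast step t, sorted for the quantile lookup.
        column = sorted(path[t] for path in samples)
        medians.append(quantile(column, 0.5))
        if quantile(column, low) <= actual[t] <= quantile(column, high):
            covered += 1
    mape = sum(abs(a - m) / abs(a) for a, m in zip(actual, medians)) / horizon
    return medians, mape, covered / horizon


# Toy data: 3 sample paths over a 2-step horizon.
samples = [[9.0, 10.0], [10.0, 11.0], [11.0, 12.0]]
actual = [10.0, 14.0]
medians, mape, coverage = summarize(samples, actual)
```

A low MAPE together with interval coverage close to the nominal 80% would suggest the forecast is both accurate and reasonably calibrated. A low MAPE alone says nothing about the uncertainty bands.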
[Figures: forecast result plots (10 images)]
https://github.com/samemelody/pyworkspace/blob/master/stockpredict.ipynb