# 股票逼近 (stock price approximation)
#
# 2018-09-09  本文已影响1人  上行彩虹人 (published 2018-09-09, author: 上行彩虹人)
import tushare
import codecs
#计算用
import numpy as np
import matplotlib.pyplot as plot
import matplotlib.dates as dates
import fileinput
import tkinter as tk
from tkinter import *
from tkinter import Button
from tkinter import Text
from tkinter import Entry
from tkinter import messagebox
import datetime #获取今日时间


# Linear model: y = w1*x1 + w2*x2 + w3*x3 + w4*x4 + w5*x5 + b,
# where x1..x5 are five consecutive closing prices.
# Every weight starts at 0.2 and the bias at 0.01; traing() updates them.
w1 = w2 = w3 = w4 = w5 = 0.2
b = 0.01

# Base step size for the gradient-descent updates.
learn_rate = 1e-5

# Placeholder series; get_data() overwrites both from data.csv.
data = np.random.rand(10)
date = np.random.rand(10)

##GUI
def Test():
    """Validate the stock code typed into the entry box, then predict and plot.

    Reads the code from the global Tk ``Entry`` widget ``text``.  An empty
    code, or one containing anything other than the ASCII digits '0'-'9' and
    lowercase letters 'a'-'z', is rejected with a single message box.
    Otherwise the history is downloaded (get_data_online), loaded
    (get_data), the model is trained (traing), and the prediction for the
    next day is shown in a message box and plotted.
    """
    t = text.get()
    if t == '':
        messagebox.showinfo('提示', '输入不能为空')
        return

    # A character is valid only if it is a digit OR a lowercase letter;
    # everything else (punctuation, spaces, non-ASCII) is reported.
    invalid = [ch for ch in t if not ('0' <= ch <= '9' or 'a' <= ch <= 'z')]
    for ch in invalid:
        print(ch)
    ok = 1 if invalid else 0
    print('ok', ok)
    if ok == 1:
        messagebox.showinfo('提示', '输入股票代码由数字或者字母')
        return

    print(t)
    # Download the history into data.csv; get_data_online returns 1 on success.
    if get_data_online(t) != 1:
        return
    get_data()
    print(data)
    prediction = traing()

    next_day = get_nextday_date()
    res = 'DATE: ' + str(next_day) + ' INDEX: ' + t + ' TOMORROW: ' + str(prediction)
    messagebox.showinfo('预测结果:', res)

    # The prices may be held as strings; convert so matplotlib plots numbers.
    prices = [float(v) for v in data]
    plot.title(res)
    plot.plot_date(date, prices, fmt='-', marker='o', c='g')
    # Bar chart of the same series.
    plot.bar(x=date, height=prices, color='red')
    plot.show()


# Download the daily history with tushare and cache the closing prices.
def get_data_online(t, start='2018-03-10'):
    """Fetch daily closing prices for stock code ``t`` and save them to data.csv.

    Args:
        t: stock/index code typed by the user, e.g. '000725'.
        start: first day of the history as 'YYYY-MM-DD'.  The default keeps
            the original hard-coded range; the end is always today.

    Returns:
        1 on success; 0 if no usable data came back for ``t``, in which case
        a message box tells the user the code does not exist.

    data.csv receives the text rendering of the ``close`` column.  Its first
    line is a header, which get_data() skips.
    """
    s = get_today_date()
    data = tushare.get_hist_data(t, start=start, end=s)
    # Presumably tushare returns None for an unknown code (TypeError when
    # subscripted), or a frame without 'close' (KeyError); handle both.
    try:
        close = data['close']
    except (TypeError, KeyError):
        print('Error: 输入股票代码不存在')
        messagebox.showinfo('提示', '输入股票代码不存在')
        return 0

    # to_string() renders every row (str()/repr truncates long series with a
    # '...' row) and leaves out the "Name: ..., dtype: ..." footer, so the
    # file needs no post-processing.  `with` closes it on every path.
    with open('data.csv', 'w') as file:
        file.write(close.to_string() + '\n')
    return 1

# Center the input window on the screen.
def center_win(root, width, height):
    """Size ``root`` to width x height pixels and place it in the screen center.

    Each offset is half of the free space on that axis; the '%d' conversion
    truncates it to whole pixels.
    """
    left = (root.winfo_screenwidth() - width) / 2
    top = (root.winfo_screenheight() - height) / 2
    root.geometry('%dx%d' % (width, height) + '+%d+%d' % (left, top))


def get_data():
    """Load data.csv into the module globals ``date`` and ``data``.

    The first line of the file is a header and is skipped.  Every following
    line with at least two whitespace-separated fields is read as
    "<YYYY-MM-DD> <close>".  Blank lines and one-field lines are ignored.

    Afterwards ``date`` is a float array of matplotlib date numbers and
    ``data`` is a string array of closing prices, both in file order.
    """
    global date, data
    day_nums = []
    closes = []
    with open('data.csv') as f:
        next(f, None)  # header line
        for line in f:
            # split() tolerates any amount of alignment padding, unlike a
            # fixed 4-space delimiter.
            fields = line.split()
            if len(fields) < 2:
                continue
            # date2num replaces the removed matplotlib.dates.bytespdate2num.
            stamp = datetime.datetime.strptime(fields[0], '%Y-%m-%d')
            day_nums.append(dates.date2num(stamp))
            closes.append(fields[1])
    date = np.array(day_nums)
    data = np.array(closes, dtype=str)


# Today's date.
def get_today_date():
    """Return today's local date as a zero-padded 'YYYY-MM-DD' string.

    get_data_online() passes it to tushare as the end of the download range.
    """
    # A single today() call avoids a mismatch across midnight.  strftime
    # zero-pads month and day and inserts both '-' separators.
    return datetime.date.today().strftime('%Y-%m-%d')

# Tomorrow's date.
def get_nextday_date():
    """Return tomorrow's local calendar date as a 'YYYY-MM-DD' string.

    This is simply today + 1 day; weekends and exchange holidays are not
    skipped.
    """
    tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
    return tomorrow.strftime('%Y-%m-%d')

# Show a prediction and plot the loaded prices.
def show(code='', prediction=''):
    """Show a prediction message box, then plot the global price series.

    Args:
        code: stock code to mention in the message (optional).
        prediction: predicted price to mention in the message (optional).

    Both arguments are optional, so a bare ``show()`` call still works.
    The series comes from the module globals ``date`` and ``data``.
    """
    next_day = get_nextday_date()
    res = '明日:' + str(next_day) + '股票:' + str(code) + '预计明日股价:' + str(prediction)
    messagebox.showinfo('预测结果:', res)
    # The prices may be held as strings; convert so matplotlib plots numbers.
    prices = [float(v) for v in data]
    plot.plot_date(date, prices, fmt='-', marker='o', c='r')
    # Bar chart of the same series.
    plot.bar(x=date, height=prices, color='red')
    plot.show()

# Training.
def traing():
    """Train the 5-lag linear model on the global ``data`` and predict.

    Model: y = w1*x1 + w2*x2 + w3*x3 + w4*x4 + w5*x5 + b.

    Training walks ``data`` from its end towards its start.  At step i the
    features are data[-i-1] .. data[-i-5] and the target is data[-i-6], so
    every element from data[-6] down to data[0] serves once as a target.
    Whenever the squared error exceeds 0.001, one gradient-descent step is
    taken with step size learn_rate / sqrt(i + 1).

    The weights and bias are module globals: they are updated in place and
    carry over to the next call.  The base ``learn_rate`` is only read, so
    repeated calls all start from the same step size.

    Returns:
        float: the prediction from data[4] .. data[0], used as x1 .. x5.

    Raises:
        ValueError: if ``data`` holds fewer than 5 values.
    """
    global w1, w2, w3, w4, w5, b

    # Convert once; the values may arrive as strings (e.g. from get_data).
    prices = [float(v) for v in data]
    if len(prices) < 5:
        raise ValueError('traing() needs at least 5 prices, got %d' % len(prices))
    print(prices[0])

    weights = [w1, w2, w3, w4, w5]
    bias = b
    for i in range(len(prices) - 5):
        features = [prices[-i - k] for k in range(1, 6)]  # x1..x5
        target = prices[-i - 6]
        y = sum(w * x for w, x in zip(weights, features)) + bias
        print('y', y)
        error = target - y
        loss = error * error
        print('loss:', loss)

        # Update the parameters only when the squared error exceeds 0.001.
        if loss > 0.001:
            # loss = (target - y)**2  =>  d(loss)/d(wk) = -2 * (target - y) * xk,
            # so descending the gradient adds step * 2 * error * xk to each wk
            # (and step * 2 * error to b).  The step decays as 1/sqrt(i+1)
            # from the unchanged base rate.
            step = learn_rate / ((i + 1) ** 0.5)
            weights = [w + step * 2 * error * x for w, x in zip(weights, features)]
            bias = bias + step * 2 * error
            print('w1', weights[0])

    w1, w2, w3, w4, w5 = weights
    b = bias

    latest = prices[4::-1]  # [data[4], data[3], data[2], data[1], data[0]] -> x1..x5
    res = sum(w * x for w, x in zip(weights, latest)) + bias
    print('res', res)
    return res



if __name__ == '__main__':
    # Input window: a prompt label, an entry box for the stock code (Test()
    # reads it through the module-level name ``text``), and a confirm button
    # that runs Test().
    root = tk.Tk()
    root.title('输入窗口')
    # Request a height that already satisfies minsize below, so center_win
    # computes its offset for the size the window actually gets.
    center_win(root, 300, 250)
    root.maxsize(600, 400)
    root.minsize(250, 250)
    la = tk.Label(root, text='请输入股票代码:预测股价20以内比较靠谱。例如:京东方 000725')
    la.pack()
    text = tk.Entry(root)
    text.pack()
    b1 = Button(root, text='确定', command=Test)
    b1.pack()
    root.mainloop()
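# 上一篇 下一篇  (blog navigation: previous / next post)
#
# 猜你喜欢  (blog sidebar: "you may also like")
#
# 热点阅读  (blog sidebar: "trending reads")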