# sklearn回归分析股票-中信银行
# (Linear regression with scikit-learn on a stock price: China CITIC Bank, 601998)
#
# 2020-02-26  Moon100

# -*- coding: utf-8 -*-

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tushare as ts
from sklearn import linear_model

# --- Download data ---
# Fetch daily history for stock 601998 via tushare. get_hist_data can hand
# back None instead of a DataFrame (e.g. when the service has no data), so
# stop with a clear message rather than an AttributeError on .iloc below.
data1 = ts.get_hist_data('601998')
if data1 is None or data1.empty:
    raise RuntimeError("tushare returned no history data for '601998'")

# Keep the first 300 rows, then sort by the index ascending so the series
# runs in index order (presumably trading dates, oldest -> newest).
# NOTE(review): tushare appears to return newest rows first, which would make
# these the 300 most recent trading days — confirm.
# .iloc already yields a DataFrame, so no extra pd.DataFrame() wrap is needed.
data1 = data1.iloc[0:300].sort_index(ascending=True, axis=0)

# --- Save data ---
# Write the trimmed frame to CSV. Create the target directory first so the
# export does not fail on a machine where it does not exist yet.
csv_path = '/home/abc/program/sklearn/a601998.csv'
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
data1.to_csv(csv_path)

# --- Build regression inputs ---
# Feature x1 = the 'ma10' column, target y1 = the 'close' column. Both are
# turned into numpy column vectors of shape (n, 1).
x1, y1 = (np.array(data1[col]).reshape(-1, 1) for col in ('ma10', 'close'))

# Plain 0..n-1 positions used as the x axis of the plot.
x0 = range(len(y1))

# The DataFrame index, kept for optional x-axis tick labels.
xk = data1.index

# --- Regression ---
# Ordinary least squares of close (y1) on ma10 (x1).
reg = linear_model.LinearRegression()
reg.fit(x1, y1)
quan = reg.coef_       # slope; shape (1, 1) because y1 is a 2-D column
pian = reg.intercept_  # intercept; shape (1,)

# In-sample fitted values, shape (n, 1). For this single-feature model this
# equals x1 * quan + pian, but predict() is the library's own formulation.
yuce = reg.predict(x1)

# --- Convert the (n, 1) arrays to Series ---
# Flatten each column array directly into a 1-D Series. The names 'a', 'b'
# and 'yuce' are what print() shows later. All three share a default 0..n-1
# index, so they align element-wise.
x1 = pd.Series(x1.ravel(), name='a')
y1 = pd.Series(y1.ravel(), name='b')
yuce = pd.Series(yuce.ravel(), name='yuce')

# Raw DataFrame columns, used here only for the type printout.
# NOTE(review): x10 holds 'close' and y10 holds 'ma10' — the reverse of
# x1 (ma10 values) / y1 (close values). Rename if they are ever used further.
x10 = data1['close']
y10 = data1['ma10']

# Debug output: the type of each object, in a fixed order.
for _series in (x1, y1, yuce, x10, y10):
    print(type(_series))

# --- Plot ---
plt.figure()
plt.grid()
plt.plot(x0, x1, c='g', linewidth=2, label='ma10')        # green line
plt.plot(x0, y1, c='b', linewidth=1.5, label='close')     # blue line
plt.plot(x0, yuce, c='tan', linewidth=1.5, label='yuce')  # tan line: regression fit
plt.ylabel('price')
plt.xlim(0, len(y1))

# Derive the vertical range from the plotted data instead of a fixed 5..7,
# which clipped every line (and the background bands) once the price moved
# outside that band. A 5% margin keeps the lines off the frame. A flat
# series (zero span) falls back to a fixed 0.5 margin.
y_low = min(x1.min(), y1.min(), yuce.min())
y_high = max(x1.max(), y1.max(), yuce.max())
span = y_high - y_low
margin = span * 0.05 if span > 0 else 0.5
y_low, y_high = y_low - margin, y_high + margin
plt.ylim(y_low, y_high)

print('-----------------')
print(y1)
print('-----------------')
print(x1)

# Boolean masks: fitted value vs ma10, and ma10 vs close.
panduan1 = yuce >= x1
panduan2 = yuce < x1
panduan3 = x1 >= y1
panduan4 = x1 < y1

# Shade the gap between ma10 and close: yellow where ma10 >= close, blue otherwise.
plt.fill_between(x0, x1, y1, where=panduan3, facecolor='yellow', alpha=0.9)
plt.fill_between(x0, x1, y1, where=panduan4, facecolor='blue', alpha=0.9)
# Full-height background bands: green where the fit >= ma10, red otherwise.
plt.fill_between(x0, y_low, y_high, where=panduan1, facecolor='green', alpha=0.3)
plt.fill_between(x0, y_low, y_high, where=panduan2, facecolor='red', alpha=0.3)

plt.legend()
plt.title('matplotlib')
# No extension given, so matplotlib appends the rcParams['savefig.format'] one
# (png by default) and writes to the current directory.
plt.savefig('main', dpi=600)
plt.show()

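# (Navigation footer from the original blog page, kept as a comment so the
#  file remains valid Python: 上一篇下一篇 / 猜你喜欢 / 热点阅读 —
#  "previous/next post", "you may also like", "trending".)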