分享

利用Tensorflow实现卡尔曼预测股票

本帖最后由 levycui 于 2019-12-10 23:18 编辑
问题导读:
1、如何使用Tensorflow实现一个Kalman预测模型?
2、如何使用Tensorflow定义KalmanFilter类?
3、如何使用Tensorflow定义correct类?
4、如何使用kalman.KalmanFilter预测数据?




前言

前几篇文章里的矩阵运算都是基于numpy实现的,这里也展示的是使用python进行矩阵运算时常用的一个库——Tensorflow。Tensorflow算是目前最火的一个三方库,在此之前雄踞榜首的三方库一直是JS。
2019-12-10_230944.jpg

本文将使用Tensorflow实现一个Kalman预测模型,用于预测股票的变化。内核仍然采用Tensorflow实现,留给用户的接口,仍然采用Numpy。K实现Kalman并不困难,只要写好初始化,预测和纠正三个接口函数,用户就能够通过Kalman进行数据预测。Kalman在惯性系统的预测和滤波中具有重要地位。

使用Tensorflow实现Kalman的类
[mw_shl_code=python,true]import tensorflow as tf
import numpy as np
class KalmanFilter(object):
      def __init__(self, x=None, A=None, P=None, B=None, H=None, Q=None):
            m = self._m = H.shape[0]
            n = self._n = x.shape[0]
            l = self._l = B.shape[1]
            self._x = tf.Variable(x, dtype=tf.float32, name="x")
            self._A = tf.constant(A, dtype=tf.float32, name="A")
            self._P = tf.Variable(P, dtype=tf.float32, name="P")
            self._B = tf.constant(B, dtype=tf.float32, name="B")
            self._Q = tf.constant(Q, dtype=tf.float32, name="Q")
            self._H = tf.constant(H, dtype=tf.float32, name="H")
            self._u = tf.placeholder(dtype=tf.float32, shape=[l, 1], name="u")
            self._z = tf.placeholder(dtype=tf.float32, shape=[m, 1], name="z")
            self._R = tf.placeholder(dtype=tf.float32, shape=[m, m], name="R")

      def predict(self):
            x = self._x
            A = self._A
            P = self._P
            B = self._B
            Q = self._Q
            u = self._u
            x_pred = x.assign(tf.matmul(A, x) + tf.matmul(B, u))
            p_pred = P.assign(tf.matmul(A, tf.matmul(P, A, transpose_b=True)) + Q)
            return x_pred, p_pred
      
   
      def correct(self):
            x = self._x
            P = self._P
            H = self._H
            z = self._z
            R = self._R
            K = tf.matmul(P, tf.matmul(tf.transpose(H), tf.matrix_inverse(tf.matmul(H, tf.matmul(P, H, transpose_b=True)) + R)))
            x_corr = x.assign(x + tf.matmul(K, z - tf.matmul(H, x)))
            P_corr = P.assign(tf.matmul((1 - tf.matmul(K, H)), P))
            return K, x_corr, P_corr
[/mw_shl_code]

读取数据并且实现它
[mw_shl_code=python,true]
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import xlrd
from plugin import kalman

rnd = np.random.RandomState(0)

'''
数据读取
'''
workbook = xlrd.open_workbook("data.xlsx")
sheet = workbook.sheet_by_name("Sheet1")

n_timesteps = 200
observations = []
x_axis = []
for i in range(1,n_timesteps+1):
      observations.append(float(sheet.cell(i,2).value))
      x_axis.append(float(i))
print(observations)
x_axis = np.array(x_axis,dtype = np.float)
observations = np.array(observations)

n = 1
m = 1
l = 1
x = np.ones([1, 1])
A = np.ones([1, 1])
B = np.zeros([1, 1])
P = np.ones([1, 1])
Q = np.array([[0.005]])
H = np.ones([1, 1])
u = np.zeros([1, 1])
R = np.array([[0.05]])
predictions = []
with tf.Session() as sess:
      kf = kalman.KalmanFilter(x=x, A=A, B=B, P=P, Q=Q, H=H)
      predict = kf.predict()
      correct = kf.correct()
      tf.global_variables_initializer().run()
      for i in range(0, n_timesteps):
            x_pred, _ = sess.run(predict, feed_dict={kf._u: u})
            predictions.append(x_pred[0, 0])
            sess.run(correct, feed_dict={kf._z:np.array([[observations]]), kf._R:R})

# 支持中文t
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.figure(figsize=(16, 6))
obs_scatter = plt.scatter(x_axis, observations, marker='x', color='b',
                         label='实际值')
position_line = plt.plot(x_axis, np.array(predictions),
                        linestyle='-', marker='o', color='r',
                        label='预测值')

plt.legend(loc='lower right')
plt.xlim(xmin=0, xmax=x_axis.max())
plt.xlabel('time')
plt.show()
[/mw_shl_code]

最终的预测效果
2019-12-10_231143.jpg
红色点为预测值,蓝色点为真实值,第0个红色点的值由用户给出,此后的任意第N个点都是将前N-1个真实点的值作为训练集,然后预测得出。在验证需要大量矩阵运算的算法时,Tensorflow会是比numpy更加方便且好用的一个库。Tensorflow基于数据流图实现,相比实现传统的复杂逻辑而言,数据流图会更好实现。矩阵运算本来就是数据的流动,采用逻辑来为矩阵运算做出规划,反而显得格格不入。

作者:cclplus
来源:https://blog.csdn.net/m0_37772174/article/details/103334216
最新经典文章,欢迎关注公众号


已有(1)人评论

跳转到指定楼层
潮水爱迟到 发表于 2020-3-10 13:30:26
大神 你的excel 数据结构能贴出来吗 用的那些数据
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

推荐上一条 /2 下一条