tensorflow-可视化

导入包

import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt

设置生成的图像尺寸和去除警告

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
plt.rcParams["figure.figsize"] = (14, 8) # 生成的图像尺寸

随机生成一个线性的数据

n_observations = 100
xs = np.linspace(-3, 3, n_observations) #生成-3到3的n为100等差数列
ys = 0.8*xs + 0.1 + np.random.uniform(-0.5, 0.5, n_observations)
plt.scatter(xs, ys) #画图
plt.show() #画图

准备placeholder

X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

初始化参数/权重

W = tf.Variable(tf.random_normal([1]), name='weight')
tf.summary.histogram('weight', W) #画图
b = tf.Variable(tf.random_normal([1]), name='bias')
tf.summary.histogram('bias', b)#画图

计算预测结果

Y_pred = tf.add(tf.multiply(X, W), b)

计算损失值

loss = tf.square(Y - Y_pred, name='loss')  #tf.square:平方

tf.summary.scalar('loss', tf.reshape(loss, []))#画图

初始化optimizer

learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

指定迭代次数,并在session里执行graph

n_samples = xs.shape[0]
init = tf.global_variables_initializer()
with tf.Session() as sess:
# 记得初始化所有变量
sess.run(init)
merged = tf.summary.merge_all()#画图
log_writer = tf.summary.FileWriter("./logs/linear_regression", sess.graph)

# 训练模型
for i in range(50):
total_loss = 0
for x, y in zip(xs, ys):
# 通过feed_dic把数据灌进去
_, loss_value, merged_summary = sess.run([optimizer, loss, merged], feed_dict={X: x, Y: y})
total_loss += loss_value

if i % 5 == 0:
print('Epoch {0}: {1}'.format(i, total_loss / n_samples))
log_writer.add_summary(merged_summary, i)#画图

# 关闭writer
log_writer.close()#画图

# 取出w和b的值
W, b = sess.run([W, b])

print(W, b)
print("W:"+str(W[0]))
print("b:"+str(b[0]))

画出线性回归线

plt.plot(xs, ys, 'bo', label='Real data')
plt.plot(xs, xs * W + b, 'r', label='Predicted data')
plt.legend()
plt.show()

Tensorboard查看图形数据

tensorboard --logdir path/to/logs(你保存文件所在位置)

如:(log_writer = tf.summary.FileWriter(“./logs/linear_regression”, sess.graph)保存的地址):

tensorboard —logdir ./logs/linear_regression

输出:

TensorBoard x.x.x at http://(你的用户名):6006 (Press CTRL+C to quit)

然后打开网页:http://localhost:6006

------ 本文结束 🎉🎉 谢谢观看 ------
0%