TensorFlow命令行参数FLAGS使用
阅读原文时间:2023年07月09日阅读:1

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

# To view the summaries written to ./tmp, run: tensorboard --logdir="./tmp"

# Command-line usage (命令行参数): python x.py --max_step=500

# Register the --max_step command-line flag (integer, default 1000):
# the number of training steps to run.
tf.app.flags.DEFINE_integer(
    "max_step",
    1000,
    "train step number",
)

# Handle to the flag values; the flag is read as FLAGS.max_step.
FLAGS = tf.app.flags.FLAGS

def linearregression(max_step=None):
    """Fit y = 0.8 * x + 0.7 with gradient descent, logging to TensorBoard.

    Builds a TF1 graph that samples a batch of 100 standard-normal inputs,
    trains a single weight and bias with plain gradient descent (learning
    rate 0.01), writes loss/weight/bias summaries to ``./tmp`` and saves a
    checkpoint to ``./model/linearregression`` when training finishes.

    Args:
        max_step: Number of training steps. ``None`` (the default) uses the
            ``--max_step`` command-line flag, as before.
    """
    if max_step is None:
        max_step = FLAGS.max_step

    with tf.variable_scope("original_data"):
        # No seed is set, so every sess.run call draws a fresh batch.
        X = tf.random_normal([100, 1], mean=0.0, stddev=1.0)
        # Ground truth: y = 0.8 * x + 0.7
        y_true = tf.matmul(X, [[0.8]]) + [[0.7]]

    with tf.variable_scope("linear_model"):
        weights = tf.Variable(initial_value=tf.random_normal([1, 1]))
        bias = tf.Variable(initial_value=tf.random_normal([1, 1]))
        y_predict = tf.matmul(X, weights) + bias

    with tf.variable_scope("loss"):
        # Mean squared error.
        loss = tf.reduce_mean(tf.square(y_predict - y_true))

    with tf.variable_scope("optimizer"):
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=0.01).minimize(loss)

    # Collect the tensors to observe in TensorBoard.
    tf.summary.scalar("losses", loss)
    tf.summary.histogram("weight", weights)
    tf.summary.histogram("biases", bias)
    # Merge all collected summaries into a single op.
    merge = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    saver = tf.train.Saver()
    model_path = "./model/linearregression"
    with tf.Session() as sess:
        sess.run(init)
        # To continue from an earlier checkpoint instead of the random
        # initial values, call saver.restore(sess, model_path) here.
        filewriter = tf.summary.FileWriter("./tmp", graph=sess.graph)
        try:
            for i in range(max_step):
                sess.run(optimizer)
                # Fetch everything in one run so the printed loss and the
                # logged "losses" summary come from the same sampled batch
                # (separate runs would each draw a new X).
                loss_val, weight_val, bias_val, summary = sess.run(
                    [loss, weights, bias, merge])
                print("loss:", loss_val, i)
                print("weight:", weight_val)
                print("bias:", bias_val)
                filewriter.add_summary(summary, i)
        finally:
            # Flush pending events to disk and release the event file.
            filewriter.close()

        # Saver.save refuses to write when the parent directory is missing.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        # Save the checkpoint files.
        saver.save(sess, model_path)

if __name__ == '__main__':
    # Train when run as a script, e.g.: python x.py --max_step=500
    linearregression()

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器