TensorFlow-Slim 简介+Demo
阅读原文 时间:2023年07月11日 阅读:3

github介绍:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim

基于slim实现的yolo-v3(测试可用):https://github.com/mystic123/tensorflow-yolo-v3

  • TF-Slim是一个轻量级tensorflow库。
  • 它可以使复杂模型的定义、训练、评估测试更简单。
  • 它的组件,可以与tensorflow的其他库(如tf.contrib.learn)混合使用。
  • 它通过消除样板代码(boilerplate code),让用户能够更紧凑地定义模型。

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import layers as layers_lib
from tensorflow.contrib import layers
import tensorflow.contrib.slim as slim
from keras.datasets import mnist
import numpy as np
import math

# Announce start-up before anything else runs.
print("Hello slim.")

# Optimisation settings.
batch_size = 1000
learning_rate = 0.01

# Divisor applied to the raw pixel values when the images are rescaled.
pixel_depth = 256

# Output directories for checkpoints and logs.
checkpoint_dir, log_dir = "./ckpts/", "./logs/"

# Get the data, mnist.npz is in ~/.keras/datasets/mnist.npz

# Load MNIST through the Keras dataset helper.
print("Loading the MNIST data in ~/.keras/datasets/mnist.npz")
(train_data, train_labels), (test_data, test_labels) = mnist.load_data()

# Images become float32 arrays of shape (N, 28, 28, 1); labels become flat
# int64 vectors.
train_data = np.reshape(train_data, (-1, 28, 28, 1)).astype(np.float32)
test_data = np.reshape(test_data, (-1, 28, 28, 1)).astype(np.float32)
train_labels = np.reshape(train_labels, -1).astype(np.int64)
test_labels = np.reshape(test_labels, -1).astype(np.int64)

# Only the first 10000 training samples are used. The scaling below is
# element-wise, so slicing before it gives the same values as slicing after.
train_data = train_data[:10000]
train_labels = train_labels[:10000]

# Rescale pixel values: x -> 2 * x / pixel_depth - 1.
train_data = 2.0 * train_data / pixel_depth - 1.0
test_data = 2.0 * test_data / pixel_depth - 1.0

print("train data shape:", train_data.shape)
print("test data shape:", test_data.shape)

# slim.nets.vgg.vgg_16

def MyModel(inputs, num_classes=10, is_training=True, dropout_keep_prob=0.5, spatial_squeeze=False, scope='MyModel'):
    """Build a small conv-pool-conv-pool-fc-fc classifier with TF-Slim.

    Args:
        inputs: Input tensor fed to the first convolution. The script in this
            file passes MNIST images reshaped to (N, 28, 28, 1).
        num_classes: Number of outputs of the last layer; the hidden fully
            connected layer has ``num_classes * num_classes`` units.
        is_training: Currently unused by the body.
        dropout_keep_prob: Currently unused by the body (there is no dropout
            layer).
        spatial_squeeze: Currently unused by the body.
        scope: Name of the variable scope that wraps every layer.

    Returns:
        The output of the final fully connected layer ``fc2``, which has
        ``num_classes`` units and no activation function.
    """
    with tf.variable_scope(scope):
        # Defaults shared by the layers listed in the arg_scope: ReLU
        # activation, truncated-normal weight init (mean 0.0, stddev 0.01)
        # and an L2 weight regularizer with scale 0.0005.
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            activation_fn=tf.nn.relu,
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            # Block 1: 8 filters of 3x3, stride 1, then 2x2 max pooling.
            net = slim.convolution2d(inputs, 8, [3, 3], 1, padding='SAME', scope='conv1')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool1')

            # Block 2: 8 filters of 5x5, stride 1, then 2x2 max pooling.
            net = slim.convolution2d(net, 8, [5, 5], 1, padding='SAME', scope='conv2')
            net = layers_lib.max_pool2d(net, [2, 2], scope='pool2')

            net = slim.flatten(net, scope='flatten1')

            # Both dense layers override the arg_scope ReLU with
            # activation_fn=None, so their outputs are linear.
            net = slim.fully_connected(net, num_classes * num_classes, activation_fn=None, scope='fc1')
            net = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc2')

    return net

def train_data_batch(batch_size):
    """Yield the next ``batch_size`` training images, wrapping at the end.

    This is a generator function: each call returns a generator that yields
    exactly one batch when iterated. The read position is stored on the
    function object itself (``train_data_batch.train_index``), so successive
    generators continue where the previously consumed one stopped.

    Args:
        batch_size: Number of samples to select.

    Yields:
        ``train_data[idx]`` for the selected indices. Only images are
        yielded; the matching labels are not.
    """
    # Lazily create the cursor the first time a generator body runs.
    if not hasattr(train_data_batch, 'train_index'):
        train_data_batch.train_index = 0

    data_size = train_labels.shape[0]
    start = train_data_batch.train_index

    # The modulo wraps indices that run past the end back to the beginning.
    idx = np.arange(start, start + batch_size) % data_size
    train_data_batch.train_index = (start + batch_size) % data_size

    yield train_data[idx]

# Plain gradient descent; the optimizer is only attached to the loss further
# down, when the train op is created.
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

# The whole training array is passed to the model in one go; no mini-batches
# are drawn here.
logits = MyModel(train_data)
loss = slim.losses.sparse_softmax_cross_entropy(logits, train_labels)

# add_regularization_losses=False leaves regularization terms out of the
# total that gets optimized.
total_loss = slim.losses.get_total_loss(add_regularization_losses=False)

train_op = slim.learning.create_train_op(total_loss, optimizer)

# checkpoint_dir is passed as the second positional argument of
# slim.learning.train; training stops after 100 steps.
slim.learning.train(
    train_op,
    checkpoint_dir,
    number_of_steps=100,
    save_summaries_secs=5,
    save_interval_secs=10,
)

print("See you, slim.")

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章