'''
Created on 2017年7月23日
@author: weizhen
'''
#导入库
from __future__ import division,print_function,absolute_import
import tflearn
import speech_data
import tensorflow as tf
# --- Hyperparameters ---
# learning_rate controls the size of each weight update: too high trains
# fast but the loss stays large/unstable; too low is accurate but slow.
learning_rate=0.0001
training_iters=300000#total training STEPS budget
batch_size=64
width=20 #mfcc features per frame
height=80 #(max) length of utterance (frames)
classes = 10 #digits 0-9
# --- Data ---
# speech_data.mfcc_batch_generator yields batches of MFCC features with
# one-hot digit labels; we pull a single batch and reuse it for both
# training and testing.
# NOTE(review): assumes each batch is (features, one-hot labels) of
# shape (batch, width, height) / (batch, classes) — verify in speech_data.
batch=word_batch=speech_data.mfcc_batch_generator(batch_size)
X,Y=next(batch)
trainX,trainY=X,Y
testX,testY=X,Y #overfit for now: train and test on the same batch on purpose
# 4. Build the model.
# Speech recognition is a many-to-many sequence problem, so a recurrent
# network is used. An LSTM is preferred over a plain RNN because it can
# learn what to remember and what to forget over long utterances.
# 128 units is a common middle ground: too few underfits, too many
# overfit. dropout=0.8 (keep probability) randomly silences units so the
# network is forced onto alternative paths and generalizes better.
# The fully connected layer maps the LSTM state to the 10 digit classes;
# softmax turns the scores into probabilities. The regression layer
# trains with the Adam optimizer to minimize categorical cross-entropy.
#Network building
net=tflearn.input_data([None,width,height])
net=tflearn.lstm(net,128,dropout=0.8)
net=tflearn.fully_connected(net,classes,activation='softmax')
net=tflearn.regression(net,optimizer='adam',learning_rate=learning_rate,loss='categorical_crossentropy')

# 5. Train the model and predict.
# Copy the trainable variables into the (legacy) VARIABLES collection —
# a workaround for a tflearn/TensorFlow collection mismatch so that
# model.save() can find every variable.
col=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for x in col:
    tf.add_to_collection(tf.GraphKeys.VARIABLES,x)

model=tflearn.DNN(net,tensorboard_verbose=0)

# BUG FIX: the original used `while 1:` with no exit condition, so
# training never terminated and `training_iters` was dead configuration.
# Each fit() call runs n_epoch=10 epochs; with len(trainX) samples and
# this batch_size that is 10 * ceil(len/batch) steps per call, so run
# only as many calls as the training_iters step budget allows.
steps_per_fit = 10 * max(1, (len(trainX) + batch_size - 1) // batch_size)
for _ in range(max(1, training_iters // steps_per_fit)):
    model.fit(trainX, trainY, n_epoch=10, validation_set=(testX,testY), show_metric=True, batch_size=batch_size)
    _y=model.predict(X)
    # Save a checkpoint each round so an interrupted run keeps progress.
    model.save("tflearn.lstm.model")
    print(_y)
下面是训练的结果
Training Step: 3097 | total loss: [1m[32m1.51596[0m[0m | time: 1.059s
[2K
Training Step: 3098 | total loss: [1m[32m1.64602[0m[0m | time: 1.050s
[2K
Training Step: 3099 | total loss: [1m[32m1.54328[0m[0m | time: 1.052s
[2K
Training Step: 3100 | total loss: [1m[32m1.65763[0m[0m | time: 1.044s
[2K
Run id: E1W1VX
Training samples: 64
Training Step: 3101 | total loss: [1m[32m1.56009[0m[0m | time: 1.328s
[2K
Training Step: 3102 | total loss: [1m[32m1.68916[0m[0m | time: 1.034s
[2K
Training Step: 3103 | total loss: [1m[32m1.58796[0m[0m | time: 1.044s
[2K
Training Step: 3104 | total loss: [1m[32m1.49236[0m[0m | time: 1.055s
[2K
Training Step: 3105 | total loss: [1m[32m1.60916[0m[0m | time: 1.028s
[2K
Training Step: 3106 | total loss: [1m[32m1.51083[0m[0m | time: 1.049s
[2K
Training Step: 3107 | total loss: [1m[32m1.63413[0m[0m | time: 1.066s
[2K
Training Step: 3108 | total loss: [1m[32m1.74167[0m[0m | time: 1.042s
[2K
Training Step: 3109 | total loss: [1m[32m1.63324[0m[0m | time: 1.051s
[2K
Training Step: 3110 | total loss: [1m[32m1.75479[0m[0m | time: 1.042s
[2K
Run id: 93CFSE
Training samples: 64
Training Step: 3111 | total loss: [1m[32m1.64290[0m[0m | time: 1.320s
[2K
Training Step: 3112 | total loss: [1m[32m1.76515[0m[0m | time: 1.029s
[2K
Training Step: 3113 | total loss: [1m[32m1.65166[0m[0m | time: 1.050s
[2K
Training Step: 3114 | total loss: [1m[32m1.76346[0m[0m | time: 1.062s
[2K
Training Step: 3115 | total loss: [1m[32m1.65255[0m[0m | time: 1.042s
[2K
Training Step: 3116 | total loss: [1m[32m1.55663[0m[0m | time: 1.042s
[2K
Training Step: 3117 | total loss: [1m[32m1.67928[0m[0m | time: 1.051s
[2K
Training Step: 3118 | total loss: [1m[32m1.78375[0m[0m | time: 1.043s
[2K
Training Step: 3119 | total loss: [1m[32m1.67364[0m[0m | time: 1.041s
[2K
Training Step: 3120 | total loss: [1m[32m1.79457[0m[0m | time: 1.044s
[2K
Run id: YE812Z
Training samples: 64
Training Step: 3121 | total loss: [1m[32m1.68830[0m[0m | time: 1.351s
[2K
Training Step: 3122 | total loss: [1m[32m1.79857[0m[0m | time: 1.022s
[2K
Training Step: 3123 | total loss: [1m[32m1.68557[0m[0m | time: 1.071s
[2K
Training Step: 3124 | total loss: [1m[32m1.58528[0m[0m | time: 1.042s
[2K
Training Step: 3125 | total loss: [1m[32m1.49228[0m[0m | time: 1.042s
[2K
Training Step: 3126 | total loss: [1m[32m1.41012[0m[0m | time: 1.052s
[2K
Training Step: 3127 | total loss: [1m[32m1.55866[0m[0m | time: 1.023s
[2K
Training Step: 3128 | total loss: [1m[32m1.46943[0m[0m | time: 1.044s
[2K
Training Step: 3129 | total loss: [1m[32m1.39050[0m[0m | time: 1.042s
[2K
Training Step: 3130 | total loss: [1m[32m1.54006[0m[0m | time: 1.043s
[2K
Run id: YGRXY5
Training samples: 64
Training Step: 3131 | total loss: [1m[32m1.45402[0m[0m | time: 1.336s
[2K
Training Step: 3132 | total loss: [1m[32m1.59202[0m[0m | time: 1.029s
[2K
Training Step: 3133 | total loss: [1m[32m1.50035[0m[0m | time: 1.070s
[2K
Training Step: 3134 | total loss: [1m[32m1.41417[0m[0m | time: 1.042s
[2K
Training Step: 3135 | total loss: [1m[32m1.34060[0m[0m | time: 1.037s
[2K
Training Step: 3136 | total loss: [1m[32m1.47476[0m[0m | time: 1.039s
[2K
Training Step: 3137 | total loss: [1m[32m1.38535[0m[0m | time: 1.053s
[2K
Training Step: 3138 | total loss: [1m[32m1.51673[0m[0m | time: 1.063s
[2K
Training Step: 3139 | total loss: [1m[32m1.42892[0m[0m | time: 1.042s
[2K
Training Step: 3140 | total loss: [1m[32m1.58217[0m[0m | time: 1.052s
[2K
注意:训练代码里的 `while 1:` 循环没有任何退出条件,所以训练会一直运行下去——这就是上面日志看起来"停不下来"的原因。应改成按 `training_iters` 计算的有限次循环,或在循环里加入收敛/步数判断后 `break`,否则只能手动中断进程。
下边是可视化训练,展示训练的图像
手机扫一扫
移动阅读更方便
你可能感兴趣的文章