A First Neural Network with TensorFlow

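The code follows the book 《TensorFlow:实战Google深度学习框架》 (TensorFlow: Practical Google Deep Learning Framework). I typed it in by hand, debugged it locally and then copied it here, so it differs somewhat from the book's version.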

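#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Written for TensorFlow 1.x. Under TensorFlow 2.x, replace the import below with
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()

import tensorflow as tf
from numpy.random import RandomState

# Number of samples per training batch
batch_size = 8

# Weights of the two layers: 2 inputs -> 3 hidden units -> 1 output
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# Inputs and labels; None lets the batch dimension vary
x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

# Forward pass (no activation function between the two layers)
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

# Loss: y is clipped to [1e-10, 1.0] so log() stays finite;
# only samples with label 1 contribute to this sum
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

# Adam optimizer with learning rate 0.001
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

# Simulated dataset: 128 random points in [0, 1)^2
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
print(X)

# Label a sample 1 if x1 + x2 < 1, otherwise 0
Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]
print(Y)

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; use global_variables_initializer()
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # Weights before training
    print(sess.run(w1))
    print(sess.run(w2))

    STEPS = 5000
    for i in range(STEPS):
        # Walk through the dataset batch_size samples at a time
        start = (i * batch_size) % dataset_size
        end = min(start + batch_size, dataset_size)
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})

        # Every 1000 steps, report the loss on the whole dataset
        if i % 1000 == 0:
            total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
            print("After %d training step(s), cross entropy on all data is %g"
                  % (i, total_cross_entropy))

    # Weights after training
    print(sess.run(w1))
    print(sess.run(w2))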
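The training loop picks each batch with start = (i*batch_size) % dataset_size and end = min(start+batch_size, dataset_size). The pure-Python sketch below isolates that arithmetic in a hypothetical helper, batch_bounds (not part of the original script), so the windows can be inspected without TensorFlow.

```python
# Reproduces the batch-selection arithmetic from the training loop above.
# batch_bounds is a hypothetical helper; the original script inlines this logic.

BATCH_SIZE = 8
DATASET_SIZE = 128
STEPS = 5000


def batch_bounds(step, batch_size=BATCH_SIZE, dataset_size=DATASET_SIZE):
    """Return the (start, end) slice the script feeds at a given step."""
    start = (step * batch_size) % dataset_size
    end = min(start + batch_size, dataset_size)
    return start, end


if __name__ == "__main__":
    # 128 / 8 = 16 steps per pass, so step 16 wraps back to the first batch
    for step in (0, 1, 15, 16, STEPS - 1):
        print(step, batch_bounds(step))

    # Total samples fed, measured in passes over the dataset
    print(STEPS * BATCH_SIZE / DATASET_SIZE)  # 312.5 under Python 3

    # With 100 samples the last window is short and later windows shift
    print(batch_bounds(12, dataset_size=100), batch_bounds(13, dataset_size=100))
```

With 128 samples and batches of 8, every window is full and the 5000 steps make 312.5 passes over the data. If dataset_size were not a multiple of batch_size, min() would shorten the last window, and the following windows would no longer start at 0.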

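Running this first TensorFlow network and comparing the weights printed before and after training shows how w1 and w2 change, while the cross entropy on the full dataset falls from 0.0674925 after the first training step to 0.00578471 at step 4000.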
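Because the script applies no activation function between its two tf.matmul calls, the network is a linear model: x·w1·w2 equals x·(w1·w2). The NumPy sketch below checks this using the trained weights from the final printout and the same RandomState(1) inputs; it is an illustration only, not part of the original code.

```python
import numpy as np
from numpy.random import RandomState

# Trained weights copied from the final printout of the script's output
w1 = np.array([[-1.9618274, 2.58235407, 1.68203783],
               [-3.4681716, 1.06982327, 2.11788988]])
w2 = np.array([[-1.8247149], [2.68546653], [1.41819501]])

# Same inputs as the script: RandomState(1), 128 points in [0, 1)^2
X = RandomState(1).rand(128, 2)

# Two-layer forward pass as in the graph, with no activation in between
a = X @ w1  # hidden layer
y = a @ w2  # output

# Collapsing the layers: one 2x1 weight matrix gives the same outputs
w_combined = w1 @ w2
y_single = X @ w_combined

print(a.shape, y.shape, w_combined.shape)  # (128, 3) (128, 1) (2, 1)
print(np.allclose(y, y_single))            # True
```

Inserting a nonlinearity such as tf.nn.relu after the first matmul would break this equivalence, which is what lets deeper networks represent non-linear decision boundaries.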

Output (the printed X and Y arrays are not shown):

[[-0.81131822 1.48459876 0.06532937]
 [-2.44270396 0.0992484 0.59122431]]
[[-0.81131822]
 [ 1.48459876]
 [ 0.06532937]]
After 0 training step(s), cross entropy on all data is 0.0674925
After 1000 training step(s), cross entropy on all data is 0.0163385
After 2000 training step(s), cross entropy on all data is 0.00907547
After 3000 training step(s), cross entropy on all data is 0.00714436
After 4000 training step(s), cross entropy on all data is 0.00578471
[[-1.9618274 2.58235407 1.68203783]
 [-3.4681716 1.06982327 2.11788988]]
[[-1.8247149 ]
 [ 2.68546653]
 [ 1.41819501]]
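The value printed as cross entropy is -mean(y_ · log(clip(y, 1e-10, 1.0))). Since y_ is 0 for every sample with x1 + x2 ≥ 1, those samples add nothing, and a label-1 sample stops contributing once its output reaches 1. The NumPy sketch below mirrors that formula in script_loss on a few made-up predictions; binary_cross_entropy is a hypothetical comparison function for the usual two-term cross entropy on a sigmoid output, which the script does not use.

```python
import numpy as np


def script_loss(y_true, y_pred):
    """The script's loss: -mean(y_ * log(clip(y, 1e-10, 1.0)))."""
    return -np.mean(y_true * np.log(np.clip(y_pred, 1e-10, 1.0)))


def binary_cross_entropy(y_true, logits):
    """Hypothetical comparison: two-term cross entropy on sigmoid(logits)."""
    p = np.clip(1.0 / (1.0 + np.exp(-logits)), 1e-10, 1.0 - 1e-10)
    return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))


# Made-up labels and raw outputs for four samples (not from the script's run)
y_true = np.array([[1.0], [0.0], [1.0], [0.0]])
y_pred = np.array([[0.5], [0.9], [2.0], [-3.0]])

# Only the first sample contributes: -log(0.5) / 4, about 0.1733
print(script_loss(y_true, y_pred))

# Changing the outputs of label-0 samples leaves the script's loss unchanged
print(script_loss(y_true, np.array([[0.5], [-7.0], [2.0], [9.0]])))

# Pushing every output to >= 1 drives the script's loss to zero
print(script_loss(y_true, np.full((4, 1), 5.0)))

# The two-term loss also penalizes label-0 samples; at logits 0 it equals log(2)
print(binary_cross_entropy(y_true, np.zeros((4, 1))))
```

A falling value in the output above therefore does not by itself show that the label-0 samples are being classified correctly.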

Original article by fendouai. If you repost it, please credit the source: http://www.buluo360.com/2017/07/01/base-tensorflow-first-neual-network/
