网站首页 文章专栏 tf.metrics.accuracy计算的是正确率吗
tf.metrics.accuracy计算的是正确率吗
创建于:2018-04-03 16:00:00 更新于:2025-01-18 10:09:36 羽瀚尘 2202
python python


被坑惨了

正确率使用tf.metrics.accuracy来计算,总是发现机器学不到什么。

然后就不断加层,调参数,各种折腾。

关键我还是几个网络在一起工作。。。。

后来终于看它不顺眼,用了一个更顺眼的方式来计算正确率,就这么从坑里出来了。

正常的计算单个batch正确率的代码
python correct_prediction = tf.equal( tf.argmax(y_, 1), tf.argmax(output, 1)) accuracy = tf.reduce_mean( tf.cast(correct_prediction, tf.float32))

tf.metrics.accuracy算的是什么?

简单的讲,它计算的是整个session生存期内,所有feed_dict中的数据的正确率。

不是单个batch的正确率

验证代码

”`python
import os

supress tensorflow logging other than errors

os.environ[‘TF_CPP_MIN_LOG_LEVEL’] = ‘3’

import numpy as np
import tensorflow as tf


print(tf.version)

1.1.0


x = tf.placeholder(tf.int32, [5])
y = tf.placeholder(tf.int32, [5])
acc, acc_op = tf.metrics.accuracy(labels=x, predictions=y)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

v = sess.run([acc, acc_op], feed_dict={x: [1, 0, 0, 0, 0],
y: [1, 0, 0, 0, 1]})
print(v)
#acc 与acc op都 运行,得到的都是正确率,不过acc是用来更新的,当前的feed_dict不计入
#另外,acc_op会维护历史数据,acc只是从历史数据中获得结果
#如果不运行acc_op, 历史数据不会更新
#所以acc输出是0, acc_op是0.8

[0.0, 0.8]


v = sess.run(acc)
print(v)
#这里单独运行acc,对历史数据做统计,输出0.8

0.8


v = sess.run([acc_op], feed_dict={x: [1, 0, 0, 0, 0],
y: [0, 1, 1, 1, 1]})
print(v)
#这里单独运行acc_op, 历史正确率是0.4

0.4



v = sess.run([acc], feed_dict={x: [0, 1, 1, 1, 1],
y: [0, 1, 1, 1, 1]})
print(v)
#这里单独运行acc,输出历史正确率0.4,历史数据得不到更新
#实际的历史正确率是0.6

0.4


sess.close()

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
v = sess.run([acc, acc_op], feed_dict={x: [1, 0, 0, 0, 0],
y: [1, 0, 0, 0, 1]})
print(v)
#重新开始一个session,重新计算正确率

[0.0, 0.8]

”`

Reference:
1. https://github.com/tensorflow/tensorflow/issues/15115
2. https://github.com/tensorflow/tensorflow/issues/9498
3. https://stackoverflow.com/questions/46409626/how-to-properly-use-tf-metrics-accuracy