我正在尝试计算计算w.r.t的导数。在tf.while_loop中循环变量,显然,这是不可能的。玩具示例:
x = tf.Variable(10)
fx = tf.square(x) #possibly very complicated function
def cond(x):
return tf.not_equal(x, 0)
def body(x):
return tf.cond(tf.gradients(fx, x)[0] > 0, lambda: x-1, lambda: x+1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
o = sess.run(tf.while_loop(cond, body, [x]))
print(o)
运行上面的代码会产生以下结果:
TypeError: unorderable types: NoneType() > int()
原因(据我所知)是循环中的变量x不是我的全局函数fx中的一个节点。
如何解决此问题?一个想法是将fx添加到循环变量中,但我找不到一种方法来维护它,除非在每个循环中完全重建计算图。
谢谢!
更新:以下更改可解决此问题:
更改函数:
def f(x):
return tf.square(x) #possibly very complicated function
并改变主体:
def body(x):
return tf.cond(tf.gradients(f(x), x)[0] > 0, lambda: x-1, lambda: x+1)
但是这种解决方案在每次迭代中重建图,这可能是低效的。
转载请注明出处:http://www.xiangbinbaiyi.com/article/20230512/1792941.html