关于循环变量的tf.gradient

我正在尝试计算计算w.r.t的导数。在tf.while_loop中循环变量,显然,这是不可能的。玩具示例:

x = tf.Variable(10)
fx = tf.square(x) #possibly very complicated function

def cond(x):
    return tf.not_equal(x, 0)

def body(x):
    return  tf.cond(tf.gradients(fx, x)[0] > 0, lambda: x-1, lambda: x+1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o = sess.run(tf.while_loop(cond, body, [x]))
    print(o)

运行上面的代码会产生以下结果:

TypeError: unorderable types: NoneType() > int()

原因(据我所知)是循环中的变量x不是我的全局函数fx中的一个节点。

如何解决此问题?一个想法是将fx添加到循环变量中,但我找不到一种方法来维护它,除非在每个循环中完全重建计算图。

谢谢!

更新:以下更改可解决此问题:

更改函数:

def f(x):
    return tf.square(x) #possibly very complicated function

并改变主体:

def body(x):
    return  tf.cond(tf.gradients(f(x), x)[0] > 0, lambda: x-1, lambda: x+1)

但是这种解决方案在每次迭代中重建图,这可能是低效的。

转载请注明出处:http://www.xiangbinbaiyi.com/article/20230512/1792941.html