scatter_nd_update(ref,indices,updates,use_locking=True,name=None
)
主要参数为三个:ref是被赋值的张量,indices是具体的索引位置,是整数类型的张量,updates是要赋值的张量,注意与ref为同样类型。该函数就是将ref[indices]的值替换为updates。
举例:
对于一维张量的坐标赋值
tensor = tf.constant([0, 0, 0, 0, 0, 0, 0, 0],dtype=tf.int32)ipdb> indices = tf.constant([[3], [1], [4], [7]])ipdb> updates = tf.constant([5, 6, 7, 12],dtype=tf.int32)ipdb> sor_scatter_nd_update(tensor, indices, updates))
输出为
tf.Tensor([ 0 6 0 5 7 0 0 12], shape=(8,), dtype=int32)
对于二维张量的坐标赋值,注意indices 中每个元素为索引的[行,列]。
tensor = tf.constant([[0, 0], [1, 1], [2, 2]],dtype=tf.int32)ipdb> indices stant( [[0, 1],[1,1], [2, 0]])ipdb> updates = tf.constant([3, 6,9],dtype=tf.int32)ipdb> sor_scatter_nd_update(tensor, indices, updates))
输出为
tf.Tensor(
[[0 3][1 6][9 2]], shape=(3, 2), dtype=int32)
利用行索引赋值整行(切片赋值)
tensor
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[0, 0],[1, 1],[2, 2]])>ipdb> indices stant( [[0], [2]])ipdb> updates = tf.constant([[3,4],[5,6] ],dtype=tf.int32)ipdb> sor_scatter_nd_update(tensor, indices, updates))
输出为
tf.Tensor(
[[3 4][1 1][5 6]], shape=(3, 2), dtype=int32)
注意如果整行的赋值,updates 需要给出这一行所有的新值。如果维度不一样,则会报错。
updates = tf.constant([[3],[4] ],dtype=tf.int32)ipdb> sor_scatter_nd_update(tensor, indices, updates))
*** tensorflow.s_impl.InvalidArgumentError: The inner 1 dimensions of output.shape=[3,2] must match the inner 1 dimensions of updates.shape=[2,1] [Op:TensorScatterUpdate]
本文发布于:2024-02-02 07:10:08,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170682900942185.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |