这样需要进行的操作就是将[5,4,3]的tensor按照第3维抽取出来,成为3个[5,4]的tensor 然后,将这3个tensor按照第1维或者第2维来进行拼接。
首先,令B = tf.split(A,3,axis=2),此时B的shape是[3,5,4,1],B是一个tensor列表,里面有3个tensor张量,每一个张量的shape都是[5,4,1]的。 然后,令C = tf.squeeze(B,3)将B中的第4个维度删除,此时C的shape为[3,5,4]如何将一个[5,4,3]的tensor A转换成[3,5,4]的tensor呢?
首先,令D=tf.unstack(C,axis=0),此时D的shape是[3,5,4]。这时有人会问为什么要做这一步?D和C有什么区别,不都是[3,5,4]的tensor吗? 其实不是,我们输出一下结构看一看,C的结构如下:然后如何将[3,5,4]的tensor转换成[5,4x3]的tensor呢?
D的结构如下:
看出差别了吧,C是一个3维的张量,而D是一个2维的张量列表。 假设E = tf.concat(C,1),F = tf.concat(D,1),我们看看E和F分别是什么?
E:
F:
看出差别了吧,由于C是一个张量,所以concat还是它自己。 而D是一个张量列表,concat的输入应该是多个tensor才能进行concat,此时就对D的多个张量按照第2维进行了连接,即得到了我们想要的结果。
加油啊💪能一直坚持下来,真的很棒