TensorFlow中tf.slice,tf.split,tf.concat,tf.stack等张量操作函数详解

昨天尝试了用reshape操作numpy矩阵,但是发现reshape对矩阵的操作并不能满足我的需求,今天继续尝试

我的目标是将一个5x4x3的tensor张量转换为4x15的张量和20x3的张量,具体转换效果如下图所示。

由于操作对象是一个tensor张量,所以不能用numpy库中的函数,必须使用tensorflow框架中提供的方法,因此下文将对几个tf函数进行尝试。

1.tensor张量的打印

由于tensorflow是张量流,它的逻辑是先搭建网络结构,网络结构搭建好后,开启一个session,此时才算是将网络结构实例化。

只有在网络实例化之后,才会将数据送入网络结构,这时才能查看网络结构中的数据。

举一个例子,我个人感觉tensorflow就是在修一个水管一样,首先给你一个入口(输入),一个出口(输出),利用不同形状的水管(不同的网络单元,比如cnn中的卷积层,池化层等)搭建出一个能够将入口和出口连通的线路(满足输入输出的网络结构),搭建时水管(网络结构)中是没有水(数据)的,只有搭建好以后,打开水龙头(初始化,sess.run)才会有水(数据)流经水管(网络结构)。

所以定义的tensor结构直接打印时,输出的只是结构,而并不是数据。比如我定义下面一个矩阵

如果sl=tf.slice(a,[1,0,0],[2,2,3]),我们直接打印sl的话,出来的结果是下图

这个就是tensor的结构,而不是tensor中的数据。

因为此时还没有将数据送入结构,我们是看不到数据的,那么如何打印呢?

我们初始化一个sess=tf.Session(),然后print sess.run(sl)即可看到结果如下

就是[2,2,3]的那个tensor对不对?这样我们就打印出了tensor中的值,后续操作将基于本知识点进行可视化。

2.tf.slice(input,begin,size)

这个函数是对input数据进行切片操作,begin和size可以接受列表类型的数据,但是要保证begin+size中的每一维都小于tensor张量该维度的大小。

第一块我们已经尝试了该函数,该函数的逻辑就是从begin开始,抽取size大小的矩阵。具体如下图所示。

3.tf.split(input,num_or_size_split,axis=0,num=None)

这个函数,是对输入进行分割,num_or_size_split可以是一个数字,就是按照axis等分为几个矩阵,也可以是个列表,列表的和应该等于该维度的大小。

num这个参数没搞懂,因为第二个参数已经把形状固定了,这个num不知有啥用。具体转换逻辑如下图

4.tf.concat(input,axis)

这个函数的输入是多个tensor,并将多个tensor按照axis拼接成一个tensor。具体转换逻辑如下图。

5.tf.stack(input,axis=0)

这个函数是将多个tensor按照axis组合为一个tensor,具体转换逻辑如下图

Stack和concat的操作是有区别的,但是我又表达不出来这种区别。有点像一个二维列表的append和+的区别。

假设A = [[1],[2]],此时A.append([3])得到的是[[1],[2],[3]],而A+[3]得到的是[[1],[2],3]。

6.tf.unstack(input,num=None,axis=0)

这个函数与stack函数是相反的作用,所以图示就不画了。

经过上述6种函数的学习与操作,终于可以解决我的问题了,那就是利用tf.concat函数即可。

最后还有一个小技巧,就是tensor也可以像列表一样操作的

用最开始的a举例子,假设us=tf.unstack(a,axis=0),此时us如下图

这样的情况下,我们可以定义co=tf.concat([t for t in us],0),输出co可得

即得到了我们需要的张量了。

点赞

发表评论

[2;3Rer>