Stack vs Concat in PyTorch, TensorFlow & NumPy

最近在重新捡起以前学习过的东西,这里写一下stack与concatenate在pytorch,tensorflow和numpy中的表达方式。最开始的参考为这个博客,建议有技术的直接看原文。

Stack与Concatenate的区别

  • Concatenate joins a sequence of tensors along an existing axis.
  • Stack joins a sequence of tensors along a new axis.

How to Add or Insert an Axis into a Tensor

这里先用pytorch进行演示:

import torch
t1 = torch.tensor([1,1,1])
  • 增加一个维度
> t1.unsqueeze(dim=0)
tensor([[1, 1, 1]])
> t1.unsqueeze(dim=1)
tensor([[1],
        [1],       
        [1]])
  • 观察形状
> print(t1.shape)
> print(t1.unsqueeze(dim=0).shape)
> print(t1.unsqueeze(dim=1).shape)

torch.Size([3])  # 注意这里tensor的初始维度!
torch.Size([1, 3])
torch.Size([3, 1])

Stack vs Cat in PyTorch

import torch
t1 = torch.tensor([1,1,1])
t2 = torch.tensor([2,2,2])
t3 = torch.tensor([3,3,3])
  • 使用cat
> torch.cat((t1,t2,t3),dim=0)

tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])   # 在原有维度上操作
  • 使用stack
> torch.stack((t1,t2,t3) ,dim=0)

tensor([[1, 1, 1],
        [2, 2, 2],      
        [3, 3, 3]])     #  在新增加的维度操作
  • 使用cat来达到与stack相同的效果
> torch.cat((t1.unsqueeze(0),
            t2.unsqueeze(0),
            t3.unsqueeze(0)),
            dim=0)

tensor([[1, 1, 1],       
        [2, 2, 2],
        [3, 3, 3]])
  • 试试第二个维度
> torch.stack((t1,t2,t3),dim=1)

tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
> torch.cat((t1.unsqueeze(1),
            t2.unsqueeze(1),
            t3.unsqueeze(1)),
            dim=1)

tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

Stack vs Concat in TensorFlow

import tensorflow as tf
t1 = tf.constant([1,1,1])
t2 = tf.constant([2,2,2])
t3 = tf.constant([3,3,3])
  • 使用concat
> tf.concat((t1,t2,t3),axis=0)

tf.Tensor: id=4, shape=(9,), dtype=int32, numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3])
  • 使用stack
> tf.stack((t1,t2,t3),axis=0)

tf.Tensor: id=6, shape=(3, 3), dtype=int32, numpy=
array([[1, 1, 1],       
        [2, 2, 2],
        [3, 3, 3]])
  • 使用concat达到与stack相同效果(注意:tensorflow中的expand_dims与Pytorch中的unsqueeze效果相同)
> tf.concat((tf.expand_dims(t1, 1),
            tf.expand_dims(t2, 1),
            tf.expand_dims(t3, 1)),
            axis=1)
            
tf.Tensor: id=15, shape=(3, 3), dtype=int32, numpy=
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
  • 试试第二维度
> tf.stack((t1,t2,t3) ,axis=1)

tf.Tensor: id=17, shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]])
> tf.concat((tf.expand_dims(t1, 0),
            tf.expand_dims(t2, 0),
            tf.expand_dims(t3, 0)),
            axis=0)
            
tf.Tensor: id=26, shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]])

Stack vs Concatenate in NumPy

import numpy as np
t1 = np.array([1,1,1])
t2 = np.array([2,2,2])
t3 = np.array([3,3,3])
  • 使用concatenate
> np.concatenate((t1,t2,t3),axis=0)

array([1, 1, 1, 2, 2, 2, 3, 3, 3])
  • 使用stack
> np.stack((t1,t2,t3),axis=0)

array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
  • 使用concatenate达到与stack相同效果
> np.concatenate((np.expand_dims(t1, 0),
                np.expand_dims(t2, 0),
                np.expand_dims(t3, 0)),
                axis=0)

array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
  • 第二维度
> np.stack((t1,t2,t3),axis=1)

array([[1, 2, 3],       
        [1, 2, 3],
        [1, 2, 3]])
> np.concatenate((np.expand_dims(t1, 1),
                    np.expand_dims(t2, 1),
                    np.expand_dims(t3, 1)),
                    axis=1)

array([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

完全相通,注意concatenate名字的区别

LibraryFunctionName
PyTorchcat()
TensorFlowconcat()
NumPyconcatenate()

实际例子操作

  • 将图片加入一个batch中
import torch
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.stack((t1,t2,t3),dim=0).shape

## output ##
torch.Size([3, 3, 28, 28])

这里使用stack操作,因为原图片没有batch_size的维度

  • 将batches加入一个单独的batch中
import torch
t1 = torch.zeros(1,3,28,28)
t2 = torch.zeros(1,3,28,28)
t3 = torch.zeros(1,3,28,28)
torch.cat((t1,t2,t3),dim=0).shape

## output ##
torch.Size([3, 3, 28, 28])

这里使用cat操作,因为原图片已经有batch_size的维度

  • 将图片加入一个已经存在的batch中
import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat((batch,torch.stack((t1,t2,t3),dim=0)),dim=0).shape

## output ##
torch.Size([6, 3, 28, 28])

先stack,合并三图片与一个batch中,再cat2个batch或如下操作也可达到相同目的:

import torch
batch = torch.zeros(3,3,28,28)
t1 = torch.zeros(3,28,28)
t2 = torch.zeros(3,28,28)
t3 = torch.zeros(3,28,28)
torch.cat((batch,torch.stack((t1,t2,t3),dim=0)),dim=0).shape

## output ##
torch.Size([6, 3, 28, 28])