定义
scatter_(dim, index, src, *, reduce=None) -> Tensor
pytorch官网说这个函数的作用是从src中把index指定的位置把数据写入到self里面,然后给了一个公式:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
这个公式我也是一脸懵,但是我们可以把他降维到二维表格上,即:
self[ index[i][j] ][j] = src[i][j] # if dim == 0
把src从 i 行 移动到了 index[i][j] 行
self[i][ index[i][j] ] = src[i][j] # if dim == 1
把src 从 j 列移动到了 index[i][j] 列
对此,个人认为比较直观的理解:
dim=0,就是把本行这个data放到本列的哪行(上下移动)
dim=1,就是把本列这个data放到本行的哪列(左右移动)
所以,index数组其实是一个位置变化的映射表。
例子1
给定src是一个顺序数组,我们可以更清楚看到这一变化过程。
>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
当我们指定 dim=0,就是把每一行data放到上下移动位置,比如我们给一个例子
>>> index = torch.tensor([ [0, 0, 0], [1, 1, 1], [2, 2, 0] ])
>>> src.scatter(dim = 0, index=index, src = src)
tensor([[1, 2, 9],[4, 5, 6],[7, 8, 9]])
可以看到,scatter之后只有 src[0][2] 发生了变化,为什么呢?
前面提到了index数组其实是一个位置变化的映射表, dim=0 时候是把src从 i 行 移动到了 index[i][j] 行(上下移动), 这里的index表 0行所有的元素都移动到了0行对应位置, 1行所有的的元素都移动到了1行对应位置, 只有2行最后一个元素移动到了0行,造成的结果就是src只有最后一个元素移动到了0行的对应位置(从src[2][2]移动到了src[0][2])
例子2
下面我们再试试dim = 1 时候 把src 从 j 列移动到了 index[i][j] 列
给定src是一个顺序数组,我们可以更清楚看到这一变化过程。
>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index给定如下
>>> index = torch.tensor([ [0, 1, 2], [0, 1, 2], [0, 1, 0] ])
>>> index
tensor([[0, 1, 2],[0, 1, 2],[0, 1, 0]])>>> src.scatter(dim = 1, index=index, src = src)
tensor([[1, 2, 3],[4, 5, 6],[9, 8, 9]])
可以看到,这里src也只有一个位置发生了变化,为什么呢?
前面提到了index数组其实是一个位置变化的映射表, dim=1 时候是把src从 i 列 移动到了 index[i][j] 列 (左右移动), 这里的index表 0行 012 列对应的元素都移动到了0行 012列 对应位置(相当于没动), 1行 012 列对应的元素都移动到了1行 012列 对应位置(相当于没动), 只有2行最后一个元素移动到了0列,造成的结果就是src只有最后一个元素移动到了2行0列的位置(从src[2][2]移动到了src[2][0] )
意义
那么这种映射这么复杂,它的意义在哪里呢?
答:一般scatter用于生成onehot向量
这里还是举个例子
我们还是拿之前的src数组
>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
我们要如何理解它呢?我们可以认为它是三只股票在昨天、今天、明天的股票价格,昨天三只股票的价格分别为1,4,7,今天三只股票的价格分别为2,5,8, 明天三只股票的价格分别为3,6,9。
现在我们要训练一个预测后天股票价格的神经网络,我们给模型的输入应该是昨天三只股票的价格、今天三只股票的价格、明天三只股票的价格,即1,4,7,2,5,8,3,6,9。同时,我们要把每个数字转化为一个onehot的向量,这样的结果是我们期望的。
所以,我们要做的事情是把src转换为一个 3*3 的矩阵,矩阵中每个元素是一个能表示0-9的10维one-hot向量。
拿一段常用的onehot生成代码说事。
def one_hot(x, n_class, dtype=torch.float32):# X shape: (batch, 1), output shape: (batch, n_class)x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res# X shape: batch_size, prices_list
def to_onehot(X, n_class):# 返回结果 shape: prices_list, batch_size, onehot_size 三维return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]
先不谈代码含义,输出结果如下
>>> to_onehot(src, 10)
[tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]),tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]), tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])]
这个结果基本符合了我们的期望,那么这个是如何做到的呢?
# X shape: batch_size, prices_list
def to_onehot(X, n_class):# 返回结果 shape: prices_list, batch_size, onehot_size 三维return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]
首先,src按照昨天,今天,明天的维度,被切分为了三个列向量 [1,4,7]、[2,5,8]、 [3,6,9] 。这三个列向量对应了我们的输出,one_hot给定一个列向量,可以转换为一个one-hot列向量组。
def one_hot(x, n_class, dtype=torch.float32):# X shape: (batch, 1), output shape: (batch, n_class)x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res
为了简单,我们举一个例子
>>> one_hot(torch.tensor([1,2,3]), 4)
tensor([[0., 1., 0., 0.],[0., 0., 1., 0.],[0., 0., 0., 1.]])>>> torch.tensor([1,2,3]).view(-1,1)
tensor([[1],[2],[3]])
可以看到,res是一个全0矩阵,scatter操作在dim=1时,是一个左右移动的位置映射表,这里的res是一个 3 * 4 的矩阵,src是一个数字,可以认为是跟res同样大小的全1矩阵,但是index是一个 3*1 的矩阵,也就是这个位置映射表可以认为是一个3行1列的映射表,即 全1矩阵的0 行 0 列映射到res的 0 行 1列,全1矩阵的1行0列映射到res的1行2列,全1矩阵的2行0列映射到res的2行3列,其他保持不变(其他都是0),dim=1这种操作就是制造了one-hot向量