您的位置:首页 > 游戏 > 手游 > 形象化理解pytorch中的tensor.scatter操作

形象化理解pytorch中的tensor.scatter操作

2024/12/25 21:43:40 来源:https://blog.csdn.net/qq_33882435/article/details/142052824  浏览:    关键词:形象化理解pytorch中的tensor.scatter操作

定义

        scatter_(dim, index, src, *, reduce=None) -> Tensor

pytorch官网说这个函数的作用是从src中把index指定的位置把数据写入到self里面,然后给了一个公式:           

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

这个公式我也是一脸懵,但是我们可以把他降维到二维表格上,即:

            self[ index[i][j] ][j] = src[i][j]  # if dim == 0

把src从 i 行 移动到了 index[i][j] 行

            self[i][ index[i][j] ] = src[i][j]  # if dim == 1

把src 从 j 列移动到了 index[i][j] 列

对此,个人认为比较直观的理解:
        dim=0,就是把本行这个data放到本列的哪行(上下移动)
        dim=1,就是把本列这个data放到本行的哪列(左右移动)

所以,index数组其实是一个位置变化的映射表

例子1

给定src是一个顺序数组,我们可以更清楚看到这一变化过程。

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

当我们指定 dim=0,就是把每一行data放到上下移动位置,比如我们给一个例子

>>> index = torch.tensor([ [0, 0, 0], [1, 1, 1], [2, 2, 0] ]) 
>>> src.scatter(dim = 0, index=index, src = src)
tensor([[1, 2, 9],[4, 5, 6],[7, 8, 9]])

可以看到,scatter之后只有 src[0][2] 发生了变化,为什么呢?

 前面提到了index数组其实是一个位置变化的映射表,  dim=0 时候是把src从 i 行 移动到了 index[i][j] 行(上下移动), 这里的index表 0行所有的元素都移动到了0行对应位置, 1行所有的的元素都移动到了1行对应位置, 只有2行最后一个元素移动到了0行,造成的结果就是src只有最后一个元素移动到了0行的对应位置(从src[2][2]移动到了src[0][2])

 例子2

下面我们再试试dim = 1 时候 把src 从 j 列移动到了 index[i][j] 列

给定src是一个顺序数组,我们可以更清楚看到这一变化过程。

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

index给定如下

>>> index = torch.tensor([ [0, 1, 2], [0, 1, 2], [0, 1, 0] ])
>>> index
tensor([[0, 1, 2],[0, 1, 2],[0, 1, 0]])>>> src.scatter(dim = 1, index=index, src = src) 
tensor([[1, 2, 3],[4, 5, 6],[9, 8, 9]])

可以看到,这里src也只有一个位置发生了变化,为什么呢?

 前面提到了index数组其实是一个位置变化的映射表,  dim=1 时候是把src从 i 列 移动到了 index[i][j] 列 (左右移动), 这里的index表 0行 012 列对应的元素都移动到了0行 012列 对应位置(相当于没动), 1行 012 列对应的元素都移动到了1行 012列 对应位置(相当于没动), 只有2行最后一个元素移动到了0列,造成的结果就是src只有最后一个元素移动到了2行0列的位置(从src[2][2]移动到了src[2][0] )

意义

那么这种映射这么复杂,它的意义在哪里呢? 

答:一般scatter用于生成onehot向量

这里还是举个例子

我们还是拿之前的src数组

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

我们要如何理解它呢?我们可以认为它是三只股票在昨天、今天、明天的股票价格,昨天三只股票的价格分别为1,4,7,今天三只股票的价格分别为2,5,8, 明天三只股票的价格分别为3,6,9。

现在我们要训练一个预测后天股票价格的神经网络,我们给模型的输入应该是昨天三只股票的价格、今天三只股票的价格、明天三只股票的价格,即1,4,7,2,5,8,3,6,9。同时,我们要把每个数字转化为一个onehot的向量,这样的结果是我们期望的。

所以,我们要做的事情是把src转换为一个 3*3 的矩阵,矩阵中每个元素是一个能表示0-9的10维one-hot向量。

拿一段常用的onehot生成代码说事。


def one_hot(x, n_class, dtype=torch.float32):# X shape: (batch, 1), output shape: (batch, n_class)x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res# X shape: batch_size, prices_list
def to_onehot(X, n_class):# 返回结果 shape: prices_list, batch_size, onehot_size 三维return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

先不谈代码含义,输出结果如下 

>>> to_onehot(src, 10) 
[tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]),tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]), tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])]

这个结果基本符合了我们的期望,那么这个是如何做到的呢? 


# X shape: batch_size, prices_list
def to_onehot(X, n_class):# 返回结果 shape: prices_list, batch_size, onehot_size 三维return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

首先,src按照昨天,今天,明天的维度,被切分为了三个列向量 [1,4,7]、[2,5,8]、 [3,6,9] 。这三个列向量对应了我们的输出,one_hot给定一个列向量,可以转换为一个one-hot列向量组。

def one_hot(x, n_class, dtype=torch.float32):# X shape: (batch, 1), output shape: (batch, n_class)x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res

为了简单,我们举一个例子


>>> one_hot(torch.tensor([1,2,3]), 4) 
tensor([[0., 1., 0., 0.],[0., 0., 1., 0.],[0., 0., 0., 1.]])>>> torch.tensor([1,2,3]).view(-1,1)  
tensor([[1],[2],[3]])

 可以看到,res是一个全0矩阵,scatter操作在dim=1时,是一个左右移动的位置映射表,这里的res是一个 3 * 4 的矩阵,src是一个数字,可以认为是跟res同样大小的全1矩阵,但是index是一个 3*1 的矩阵,也就是这个位置映射表可以认为是一个3行1列的映射表,即 全1矩阵的0 行 0 列映射到res的 0 行 1列,全1矩阵的1行0列映射到res的1行2列,全1矩阵的2行0列映射到res的2行3列,其他保持不变(其他都是0),dim=1这种操作就是制造了one-hot向量

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com