您的位置:首页 > 游戏 > 手游 > Pytorch中tensor与ndarray类型转换及标量转换

Pytorch中tensor与ndarray类型转换及标量转换

2024/12/23 5:44:15 来源:https://blog.csdn.net/bbaaa123/article/details/141355060  浏览:    关键词:Pytorch中tensor与ndarray类型转换及标量转换

tensor与ndarrary的转换

Pytorch中的tensor与ndarray在底层数据类型设计有相似之处,在Pytorch框架中tensor与ndarray可以较为方便地转换

tensor转ndarray

tensor转ndarray分为浅拷贝与深拷贝

浅拷贝

浅拷贝一般使用numpy()方法

import torch
import numpy as npdata1 = torch.tensor([1, 2, 3])
print(data1)
data2 = data1.numpy()
print(data2)
data1[0] = 9
print(data1)
print(data2)
# tensor([1, 2, 3])
# [1 2 3]
# tensor([9, 2, 3])
# [9 2 3]

可以看到,在对转换成ndarray类型的data2进行修改后,tensor的值也随之改变,这是因为二者底层共用一块,为浅拷贝

深拷贝

深拷贝我们可以对tensor进行clone()后再进行转换,clone()会拷贝一份完全独立的张量,并会拷贝计算图

import torch
import numpy as npdata1 = torch.tensor([1, 2, 3])
print(data1)
data2 = data1.clone().numpy()
print(data2)
data1[0] = 9
print(data1)
print(data2)
# tensor([1, 2, 3])
# [1 2 3]
# tensor([9, 2, 3])
# [1 2 3]

可以看到这里在对张量进行修改后,并不会影响ndarray,因为这里为深拷贝

ndarray转tensor

ndarray转tensor同样分为深拷贝和浅拷贝

浅拷贝

浅拷贝一般是通过torch.from_numpy()实现的

import torch
import numpy as npdata1 = np.array([1, 2, 3])
data2 = torch.from_numpy(data1)
print(data1)
print(data2)
data1[0] = 9
print(data1)
print(data2)
# [1 2 3]
# tensor([1, 2, 3], dtype=torch.int32)
# [9 2 3]
# tensor([9, 2, 3], dtype=torch.int32)

可以看到浅拷贝后,对共享内存的任意一个对象修改都会影响到另一个的值

深拷贝

深拷贝这里我们可以通过对ndarray进行copy()进行深拷贝创立副本

import torch
import numpy as npdata1 = np.array([1, 2, 3])
data2 = torch.from_numpy(data1.copy())
print(data1)
print(data2)
data1[0] = 9
print(data1)
print(data2)
# [1 2 3]
# tensor([1, 2, 3], dtype=torch.int32)
# [9 2 3]
# tensor([1, 2, 3], dtype=torch.int32)

张量提取标量

tensor可以分为矢量张量和标量张量,对于从张量中提取标量值一般可以使用item()方法,要求tensor为单个元素才可以使用

import torch
import numpy as npdata1 = torch.tensor(1)
data2 = torch.tensor([1])
print(data1)
print(data2)
print(data1.item())
print(data2.item())
# tensor(1)
# tensor([1])
# 1
# 1
import torch
import numpy as npdata1 = torch.tensor([1, 2, 3])
print(data1)print(data1.item())
tensor([1, 2, 3])
# Traceback (most recent call last):
#   File "D:\Pythonproject\teach_day_01\demo02.py", line 7, in <module>
#     print(data1.item())
#           ^^^^^^^^^^^^
# RuntimeError: a Tensor with 3 elements cannot be converted to Scalar

可以看到非标量张量无法进行item()标量值提取

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com