pytorch张量形状怎么改变

在PyTorch中,你可以使用reshape()、view()或expand()等方法来改变张量的形状。以下是这些方法的简要说明和示例:

reshape():此方法返回一个具有新形状的张量,但内存布局与原张量相同。如果新形状的元素总数与原张量不同,则会引发错误。

import torch
x = torch.tensor([1, 2, 3, 4])
y = x.reshape(2, 2)
print(y)

view():此方法返回一个具有新形状的张量,内存布局可能与原张量不同。为了使view()正常工作,新形状的元素总数必须与原张量相同。

import torch
x = torch.tensor([1, 2, 3, 4])
y = x.view(2, 2)
print(y)

expand():此方法返回一个具有新形状的张量,它是原张量的副本,内存布局相同。expand()不会修改原张量,而是创建一个新的张量。

import torch
x = torch.tensor([1, 2, 3, 4])
y = x.expand(2, 2)
print(y)

请注意,view()和reshape()方法在处理多维张量时可能会产生不同的结果。view()要求所有维度的大小都必须是整数,而reshape()则允许使用整数和负数。负数表示从末尾开始计数的维度大小。

例如,以下代码将创建一个形状为(3, -1)的张量:

import torch
x = torch.tensor([1, 2, 3, 4])
y = x.view(3, -1)
print(y)

这将输出:

tensor([[1., 2.],
        [3., 4.]])