(扩展)图结构的批处理

常规数据的批处理

常规的数据的批处理就是简单地将多个样本“堆叠”成更大的张量。比如图像数据,假如每个图像数据是一个三通道 32×32 大小的图片,那么将他们堆叠起来就可以制作 Batch。

import torch

# each sample is a image, whose shape is 3x32x32
batch_images = []
for _ in range 10:
    batch_images.append(torch.randn(3, 32, 32))

# stack to make the batch
batch = torch.stack(batch_image, dim=0)  # shape: [10, 3, 32, 32]

由此可见常规的数据批处理的一个前提是每个样本张量的维度是相同的。但是如果是分子图数据的话,由于不同分子的原子数不同,对应的样本张量的维度各不相同。因此 PyG 使用了另一种批处理数据的手法。

图数据的批处理

由于不同的图数据节点数和边数都不尽相同,所以处理的时候无法简单地“堆叠”起来,所以处理方法为,将多个图合并为一个大的超级图,并通过额外的 batch 熟悉标记每个节点的归属。

举个简单的例子,假如有三个分子,甲烷、氨气和水,截断半径设成 1.5 Å,每个单独的图属性如下

>>> methane.atomic_numbers
tensor([6, 1, 1, 1, 1])
>>> methane.pos
tensor([[ 0.00000,  0.00000,  0.00000],
        [ 0.64051, -0.64051,  0.64051],
        [ 0.64051,  0.64051, -0.64051],
        [-0.64051,  0.64051,  0.64051],
        [-0.64051, -0.64051, -0.64051]])
>>> methane.edge_index
tensor([[0, 0, 0, 0, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 0, 0, 0]])
>>> 
>>> ammonia.atomic_numbers
tensor([7, 1, 1, 1])
>>> ammonia.pos
tensor([[ 0.00000,  0.00000, -0.09432],
        [ 0.49170, -0.85165,  0.22008],
        [ 0.49170,  0.85165,  0.22008],
        [-0.98341,  0.00000,  0.22008]])
>>> ammonia.edge_index
tensor([[0, 0, 0, 1, 2, 3],
        [1, 2, 3, 0, 0, 0]])
>>> 
>>> water.atomic_numbers
tensor([8, 1, 1])
>>> water.pos
tensor([[0.00000, -0.01840,  0.00000],
        [0.00000,  0.53834, -0.78305],
        [0.00000,  0.53834,  0.78305]])
>>> water.edge_index
tensor([[0, 0, 1, 2],
        [1, 2, 0, 0]])

三个分子分别由 5、4 和 3 个原子组成,制作 Batch 的时候将他们拼成一张大图,包含 5+4+3=12 个原子,atomic_numberspos 等原子级别的属性将会拼接起来(Concatenate),之后使用 batch 属性变量记录每个原子所属的子图:

>>> batch_data.atomic_numbers
tensor([6, 1, 1, 1, 1, 7, 1, 1, 1, 8, 1, 1])
>>> batch_data.pos
tensor([[ 0.00000,  0.00000,  0.00000],
        [ 0.64051, -0.64051,  0.64051],
        [ 0.64051,  0.64051, -0.64051],
        [-0.64051,  0.64051,  0.64051],
        [-0.64051, -0.64051, -0.64051],
        [ 0.00000,  0.00000, -0.09432],
        [ 0.49170, -0.85165,  0.22008],
        [ 0.49170,  0.85165,  0.22008],
        [-0.98341,  0.00000,  0.22008],
        [ 0.00000, -0.01840,  0.00000],
        [ 0.00000,  0.53834, -0.78305],
        [ 0.00000,  0.53834,  0.78305]])
>>> batch_data.batch
tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2])

同理,其他的原子级别的性质或中间张量(如节点特征)也如此拼接起来即可。而 edge_index 用作边的索引则会自动偏移

>>> batch_data.edge_index
tensor([[0, 0, 0, 0, 1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 9, 10, 11],
        [1, 2, 3, 4, 0, 0, 0, 0, 6, 7, 8, 5, 5, 5, 10, 11, 9, 9]])

这样记录边的索引之后,有关边的张量(如 RBF 展开,边长度等)也只需拼接起来即可。尽管以一张大图方式储存,由于不同子图之间的节点是没有连接的,所呈现的依然是一张张小子图。

最后图级别的属性或性质(如能量、晶胞等)只要像传统批处理一样堆叠即可,不过实际 PyG 的 Data 类中不会花心思去判断你的性质是图级别的、原子级别的、还是边级别的,统一都是使用的 torch.cat 而非 torch.stack,因此图级别的性质也需要留一个维度以供 Batch 操作。比如能量的形状应为 [1,],晶胞张量的形状应为 [1, 3, 3] 等等。