Pytorch PyG实现EdgePool图分类

作者:实力 时间:2022-03-17 15:47:55 

EdgePool简介

EdgePool是一种用于图分类的卷积神经网络(Convolutional Neural Network,CNN)模型。其主要思想是通过 edge pooling 上下采样优化图像大小,减少空间复杂度,提高分类性能。

实现步骤

 数据准备

一般来讲,在构建较大规模数据集时,我们都需要对数据进行规范、归一和清洗处理,以便后续语义分析或深度学习操作。而在图像数据集中,则需使用特定的框架或工具库完成。

# 导入MNIST数据集
from torch_geometric.datasets import MNISTSuperpixels
# 加载数据、划分训练集和测试集
dataset = MNISTSuperpixels(root='./mnist', transform=Compose([ToTensor(), NormalizeMeanStd()]))
data = dataset[0]
# 定义超级参数
num_features = dataset.num_features
num_classes = dataset.num_classes
# 构建训练集和测试集索引文件
train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
train_mask[:60000] = 1
test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
test_mask[60000:] = 1
# 创建数据加载器
train_loader = DataLoader(data[train_mask], batch_size=32, shuffle=True)
test_loader = DataLoader(data[test_mask], batch_size=32, shuffle=False)

实现模型

在定义EdgePool模型时,我们需要重新考虑网络结构中的上下采样操作,以便让整个网络拥有更强大的表达能力,从而学习到更复杂的关系。

from torch.nn import Linear
from torch_geometric.nn import EdgePooling
class EdgePool(torch.nn.Module):
   def __init__(self, dataset):
       super(EdgePool, self).__init__()
       # 定义输入与输出维度数
       self.input_dim = dataset.num_features
       self.hidden_dim = 128
       self.output_dim = 10
       # 定义卷积层、归一化层和pooling层等
       self.conv1 = GCNConv(self.input_dim, self.hidden_dim)
       self.norm1 = BatchNorm1d(self.hidden_dim)
       self.pool1 = EdgePooling(self.hidden_dim)
       self.conv2 = GCNConv(self.hidden_dim, self.hidden_dim)
       self.norm2 = BatchNorm1d(self.hidden_dim)
       self.pool2 = EdgePooling(self.hidden_dim)
       self.conv3 = GCNConv(self.hidden_dim, self.hidden_dim)
       self.norm3 = BatchNorm1d(self.hidden_dim)
       self.pool3 = EdgePooling(self.hidden_dim)
       self.lin = torch.nn.Linear(self.hidden_dim, self.output_dim)
   def forward(self, x, edge_index, batch):
       x = F.relu(self.norm1(self.conv1(x, edge_index)))
       x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
       x = F.relu(self.norm2(self.conv2(x, edge_index)))
       x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
       x = F.relu(self.norm3(self.conv3(x, edge_index)))
       x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
       x = global_mean_pool(x, batch)
       x = self.lin(x)
       return x

在上述代码中,我们使用了不同的卷积层、池化层和全连接层等神经网络功能块来构建EdgePool模型。其中,每个 GCNConv 层被保持为128的隐藏尺寸;BatchNorm1d是一种旨在提高收敛速度并增强网络泛化能力的方法;EdgePooling是一种在 GraphConvolution 上附加的特殊类别,它将给定图下采样至其一半的大小,并返回缩小后的图与两个跟踪full-graph-to-pool双向映射(keep and senders)的 edge index(edgendarcs)。 在这种情况下传递 None ,表明 batch 未更改。

模型训练

在定义好 EdgePool 网络结构之后,需要指定合适的优化器、损失函数,并控制训练轮数、批量大小与学习率等超参数。同时还要记录大量日志信息,方便后期跟踪和驾驶员。

# 定义训练计划,包括损失函数、优化器及迭代次数等
train_epochs = 50
learning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(edge_pool.parameters(), lr=learning_rate)
losses_per_epoch = []
accuracies_per_epoch = []
for epoch in range(train_epochs):
   running_loss = 0.0
   running_corrects = 0.0
   count = 0.0
   for samples in train_loader:
       optimizer.zero_grad()
       x, edge_index, batch = samples.x, samples.edge_index, samples.batch
       out = edge_pool(x, edge_index, batch)
       label = samples.y
       loss = criterion(out, label)
       loss.backward()
       optimizer.step()
       running_loss += loss.item() / len(train_loader.dataset)
       pred = out.argmax(dim=1)
       running_corrects += pred.eq(label).sum().item() / len(train_loader.dataset)
       count += 1
   losses_per_epoch.append(running_loss)
   accuracies_per_epoch.append(running_corrects)
   if (epoch + 1) % 10 == 0:
       print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format(
           epoch + 1, train_epochs, running_loss, running_corrects))

在训练过程中,我们遍历了每个批次的数据,并通过反向传播算法进行优化,并更新了 loss 和 accuracy 输出值。 同时方便可视化与记录,需要将训练过程中的 loss 和 accuracy 输出到相应的容器中,以便后期进行分析和处理。

来源:https://juejin.cn/post/7224127112709652538

标签:Pytorch,PyG,EdgePool,图分类
0
投稿

猜你喜欢

  • python编程开发之textwrap文本样式处理技巧

    2022-03-20 18:48:26
  • Python中 map()函数的用法详解

    2021-03-25 08:16:44
  • Python爬虫突破反爬虫机制知识点总结

    2021-09-17 12:38:22
  • 将 Ubuntu 16 和 18 上的 python 升级到最新 python3.8 的方法教程

    2022-12-16 07:50:17
  • mysql常用备份命令和shell备份脚本分享

    2024-01-13 14:37:35
  • 实例讲解Python爬取网页数据

    2023-01-10 03:55:05
  • python类中super()和__init__()的区别

    2021-04-17 16:03:02
  • PHP日期和时间函数的使用示例详解

    2023-06-28 07:28:25
  • 百万级asp分页存储过程代码(ver2.0)

    2007-12-17 13:13:00
  • 教你为SQL Server数据库构造安全门

    2009-01-20 11:34:00
  • TensorFlow MNIST手写数据集的实现方法

    2022-12-19 19:45:02
  • 浅谈python 导入模块和解决文件句柄找不到问题

    2023-12-07 03:40:07
  • JavaScript十二月新标准ECMA262v5快速浏览

    2009-12-27 12:56:00
  • 简述Redis和MySQL的区别

    2024-01-25 01:18:34
  • Python安装第三方库的3种方法

    2022-02-03 03:10:47
  • 蚁群算法js版

    2008-10-08 10:15:00
  • python通过urllib2获取带有中文参数url内容的方法

    2022-07-26 10:35:58
  • C#实现Excel表数据导入Sql Server数据库中的方法

    2024-01-19 01:19:01
  • 在oracle 数据库查询的select 查询字段中关联其他表的方法

    2009-08-31 12:27:00
  • python 三种方法提取pdf中的图片

    2023-09-18 08:25:58
  • asp之家 网络编程 m.aspxhome.com