pytorch教程网络和损失函数的可视化代码示例

作者:xz1308579340 时间:2023-11-26 16:13:51 

1.效果

pytorch教程网络和损失函数的可视化代码示例

2.环境

1.pytorch
2.visdom
3.python3.5

3.用到的代码


# coding:utf8
import torch
from torch import nn, optim   # nn 神经网络模块 optim优化函数模块
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
from visdom import Visdom  # 可视化处理模块
import time
import numpy as np
# 可视化app
viz = Visdom()
# 超参数
BATCH_SIZE = 40
LR = 1e-3
EPOCH = 2
# 判断是否使用gpu
USE_GPU = True
if USE_GPU:
   gpu_status = torch.cuda.is_available()
else:
   gpu_status = False
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
# 数据引入
train_dataset = datasets.MNIST('../data', True, transform, download=False)
test_dataset = datasets.MNIST('../data', False, transform)
train_loader = DataLoader(train_dataset, BATCH_SIZE, True)
# 为加快测试,把测试数据从10000缩小到2000
test_data = torch.unsqueeze(test_dataset.test_data, 1)[:1500]
test_label = test_dataset.test_labels[:1500]
# visdom可视化部分数据
viz.images(test_data[:100], nrow=10)
#viz.images(test_data[:100], nrow=10)
# 为防止可视化视窗重叠现象,停顿0.5秒
time.sleep(0.5)
if gpu_status:
   test_data = test_data.cuda()
test_data = Variable(test_data, volatile=True).float()
# 创建线图可视化窗口
line = viz.line(np.arange(10))
# 创建cnn神经网络
class CNN(nn.Module):
   def __init__(self, in_dim, n_class):
       super(CNN, self).__init__()
       self.conv = nn.Sequential(
           # channel 为信息高度 padding为图片留白 kernel_size 扫描模块size(5x5)
           nn.Conv2d(in_channels=in_dim, out_channels=16,kernel_size=5,stride=1, padding=2),
           nn.ReLU(),
           # 平面缩减 28x28 >> 14*14
           nn.MaxPool2d(kernel_size=2),
           nn.Conv2d(16, 32, 3, 1, 1),
           nn.ReLU(),
           # 14x14 >> 7x7
           nn.MaxPool2d(2)
       )
       self.fc = nn.Sequential(
           nn.Linear(32*7*7, 120),
           nn.Linear(120, n_class)
       )
   def forward(self, x):
       out = self.conv(x)
       out = out.view(out.size(0), -1)
       out = self.fc(out)
       return out
net = CNN(1,10)
if gpu_status :
   net = net.cuda()
   #print("#"*26, "使用gpu", "#"*26)
else:
   #print("#" * 26, "使用cpu", "#" * 26)
   pass
# loss、optimizer 函数设置
loss_f = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)
# 起始时间设置
start_time = time.time()
# 可视化所需数据点
time_p, tr_acc, ts_acc, loss_p = [], [], [], []
# 创建可视化数据视窗
text = viz.text("<h1>convolution Nueral Network</h1>")
for epoch in range(EPOCH):
   # 由于分批次学习,输出loss为一批平均,需要累积or平均每个batch的loss,acc
   sum_loss, sum_acc, sum_step = 0., 0., 0.
   for i, (tx, ty) in enumerate(train_loader, 1):
       if gpu_status:
           tx, ty = tx.cuda(), ty.cuda()
       tx = Variable(tx)
       ty = Variable(ty)
       out = net(tx)
       loss = loss_f(out, ty)
       #print(tx.size())
       #print(ty.size())
       #print(out.size())
       sum_loss += loss.item()*len(ty)
       #print(sum_loss)
       pred_tr = torch.max(out,1)[1]
       sum_acc += sum(pred_tr==ty).item()
       sum_step += ty.size(0)
       # 学习反馈
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       # 每40个batch可视化一下数据
       if i % 40 == 0:
           if gpu_status:
               test_data = test_data.cuda()
           test_out = net(test_data)
           print(test_out.size())
           # 如果用gpu运行out数据为cuda格式需要.cpu()转化为cpu数据 在进行比较
           pred_ts = torch.max(test_out, 1)[1].cpu().data.squeeze()
           print(pred_ts.size())
           rightnum = pred_ts.eq(test_label.view_as(pred_ts)).sum().item()
           #rightnum =sum(pred_tr==ty).item()
           #  sum_acc += sum(pred_tr==ty).item()
           acc =  rightnum/float(test_label.size(0))
           print("epoch: [{}/{}] | Loss: {:.4f} | TR_acc: {:.4f} | TS_acc: {:.4f} | Time: {:.1f}".format(epoch+1, EPOCH,
                                   sum_loss/(sum_step), sum_acc/(sum_step), acc, time.time()-start_time))
           # 可视化部分
           time_p.append(time.time()-start_time)
           tr_acc.append(sum_acc/sum_step)
           ts_acc.append(acc)
           loss_p.append(sum_loss/sum_step)
           viz.line(X=np.column_stack((np.array(time_p), np.array(time_p), np.array(time_p))),
                    Y=np.column_stack((np.array(loss_p), np.array(tr_acc), np.array(ts_acc))),
                    win=line,
                    opts=dict(legend=["Loss", "TRAIN_acc", "TEST_acc"]))
           # visdom text 支持html语句
           viz.text("<p style='color:red'>epoch:{}</p><br><p style='color:blue'>Loss:{:.4f}</p><br>"
                    "<p style='color:BlueViolet'>TRAIN_acc:{:.4f}</p><br><p style='color:orange'>TEST_acc:{:.4f}</p><br>"
                    "<p style='color:green'>Time:{:.2f}</p>".format(epoch, sum_loss/sum_step, sum_acc/sum_step, acc,
                                                                      time.time()-start_time),
                    win=text)
           sum_loss, sum_acc, sum_step = 0., 0., 0.

来源:https://blog.csdn.net/xz1308579340/article/details/85015343

标签:pytorch,可视化,损失函数
0
投稿

猜你喜欢

  • Ubuntu+python将nii图像保存成png格式

    2022-05-11 10:04:51
  • 如何用python 操作MongoDB数据库

    2024-01-27 16:53:10
  • javascript forEach通用循环遍历方法

    2024-04-29 13:19:14
  • Python实现测试磁盘性能的方法

    2022-01-31 19:00:46
  • 详解python中xlrd包的安装与处理Excel表格

    2021-10-23 06:06:59
  • uploadify在Firefox下丢失session问题的解决方法

    2024-02-27 01:33:31
  • python实现截取屏幕保存文件,删除N天前截图的例子

    2021-09-19 18:13:49
  • python迭代器常见用法实例分析

    2023-07-12 02:40:54
  • Jsp生成页面验证码的方法[附代码]

    2023-06-25 07:46:42
  • VMware中安装CentOS7(设置静态IP地址)并通过docker容器安装mySql数据库(超详细教程)

    2024-01-14 02:58:23
  • 详细介绍基于MySQL的搜索引擎MySQL-Fullltext

    2024-01-27 16:30:30
  • mysql8.0.23 linux(centos7)安装完整超详细教程

    2024-01-18 23:59:43
  • python实现库存商品管理系统

    2023-06-01 06:37:29
  • JavaScipt中栈的实现方法

    2024-04-18 09:33:49
  • python 字符串详解

    2022-09-27 04:44:25
  • Oracle数据库下载及安装图文操作步骤

    2024-01-26 11:15:49
  • python中将\\\\uxxxx转换为Unicode字符串的方法

    2023-11-04 15:20:07
  • Selenium+BeautifulSoup+json获取Script标签内的json数据

    2023-06-17 09:30:18
  • python 数据挖掘算法的过程详解

    2022-11-17 09:09:19
  • Python实现softmax反向传播的示例代码

    2021-02-24 10:54:02
  • asp之家 网络编程 m.aspxhome.com