PyTorch提取模型特征向量并保存至CSV的例子

作者:朴素.无恙 时间:2022-09-28 00:41:17 

PyTorch提取模型特征向量


# -*- coding: utf-8 -*-
"""
dj
"""
import torch
import torch.nn as nn
import os
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.models as models
import pretrainedmodels
import pandas as pd
class FCViewer(nn.Module):
    """Flatten every dimension after the batch dimension.

    Maps a tensor of shape ``(N, d1, d2, ...)`` to ``(N, d1 * d2 * ...)``.
    Here it turns the ``(N, C, 1, 1)`` output of ``AdaptiveAvgPool2d(1)``
    into an ``(N, C)`` feature matrix.
    """

    def forward(self, x):
        # ``reshape`` gives the same result as ``view`` for contiguous input.
        # It also accepts non-contiguous tensors (e.g. after a transpose or
        # permute), where ``view`` raises a RuntimeError.
        return x.reshape(x.size(0), -1)
class M(nn.Module):
    """Image feature extractor built on a ``pretrainedmodels`` backbone.

    The backbone's last two child modules are removed. They are replaced by
    a global ``AdaptiveAvgPool2d(1)`` followed by a flatten, so ``forward``
    returns one feature vector per image.

    Args:
        backbone1: Name of a model constructor in ``pretrainedmodels``
            (e.g. ``'resnet18'``).
        drop: Accepted for backward compatibility; it has no effect. The
            original code branched on ``drop > 0`` but built the same head
            in both branches.
        pretrained: If truthy, load the ImageNet weights; otherwise the
            backbone is constructed with ``pretrained=None``.
    """

    def __init__(self, backbone1, drop, pretrained=True):
        super().__init__()
        img_model = pretrainedmodels.__dict__[backbone1](
            num_classes=1000,
            pretrained='imagenet' if pretrained else None,
        )
        # NOTE(review): [:-2] assumes the last two children are the final
        # pooling layer and the classifier. Confirm this for backbones other
        # than resnet18.
        encoder_layers = list(img_model.children())[:-2]
        encoder_layers.append(nn.AdaptiveAvgPool2d(1))
        self.img_encoder = nn.Sequential(*encoder_layers)
        # Flatten (N, C, 1, 1) -> (N, C). No dropout layer is added.
        self.img_fc = nn.Sequential(FCViewer())

    def forward(self, x_img):
        """Return one flattened feature vector per image in ``x_img``.

        Assumes ``x_img`` is a 4-D image batch ``(N, 3, H, W)``. TODO:
        confirm with callers.
        """
        features = self.img_encoder(x_img)
        return self.img_fc(features)
model1 = M('resnet18', 0, pretrained=True)
# Inference mode. BatchNorm must use its stored running statistics rather
# than the statistics of the current batch, which here is a single image.
# The running statistics must also not be updated while extracting features.
model1.eval()
features_dir = '/home/cc/Desktop/features'
transform1 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # NOTE(review): ImageNet-pretrained backbones are normally fed inputs
    # normalized with the ImageNet mean/std (transforms.Normalize). This
    # pipeline does not normalize; confirm that this is intended.
    transforms.ToTensor()])
file_path = '/home/cc/Desktop/picture'
# Sort for a reproducible row order, because os.listdir order is arbitrary.
# The CSV rows carry no file name, so this order is the only link between a
# row and its image.
names = sorted(os.listdir(file_path))
print(names)
out_csv = os.path.join(features_dir, '3.csv')
for name in names:
    pic = os.path.join(file_path, name)
    if not os.path.isfile(pic):
        # Skip sub-directories and other non-regular entries.
        continue
    with Image.open(pic) as img:
        # Force 3 channels. Grayscale, RGBA or palette images would
        # otherwise yield a tensor that the 3-channel first conv rejects.
        img1 = transform1(img.convert('RGB'))
    x = torch.unsqueeze(img1, dim=0).float()
    with torch.no_grad():
        y = model1(x)
    # tolist() turns the values into Python floats, so the CSV text is
    # formatted as before.
    y = y.numpy().tolist()
    test = pd.DataFrame(data=y)
    # Append mode: each image adds one row, and re-running the script keeps
    # appending to an existing 3.csv.
    test.to_csv(out_csv, mode='a+', index=False, header=False)

加载训练好的模型


import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers, a shortcut, then ReLU.

    Args:
        inchannel: Number of input channels.
        outchannel: Number of output channels.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut); values other than 1 downsample the spatial size.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        # The attribute names ``left`` and ``shortcut`` are part of the
        # state_dict keys. Checkpoints loaded with load_state_dict depend on
        # them, so they must not be renamed.
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        # Identity shortcut by default: an empty Sequential returns its input.
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            # Projection shortcut, so the skip path matches the main path's
            # channel count and spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        return F.relu(out)

class ResNet(nn.Module):
    """Four-stage ResNet with a 3x3 stride-1 stem, two blocks per stage.

    It was apparently designed for 32x32 inputs (e.g. CIFAR-10; TODO
    confirm). The head uses global average pooling, so other input sizes
    also produce ``(N, num_classes)`` logits.

    Args:
        ResidualBlock: Block class, called as ``block(in_ch, out_ch, stride)``.
            The parameter name shadows the module-level class of the same
            name; it is kept for keyword-argument compatibility.
        num_classes: Number of output logits.
    """

    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        # 3x3, stride 1, padding 1 and no max-pool, so the stem preserves
        # the input's spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)  # e.g. stride=2, 2 blocks -> [2, 1]
        layers = []
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            # The next block consumes this block's output channels.
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool. For 32x32 inputs layer4 yields a 4x4 map, so
        # this equals the former fixed ``F.avg_pool2d(out, 4)``. For other
        # input sizes it averages the whole map. The fixed 4x4 window
        # ignored the map's border or produced a tensor that did not fit
        # ``self.fc``.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.fc(out)

def ResNet18(num_classes=10):
    """Build a ``ResNet`` using ``ResidualBlock`` blocks.

    Args:
        num_classes: Number of output logits. Defaults to 10, as before.

    Returns:
        A freshly initialised ``ResNet`` instance.
    """
    return ResNet(ResidualBlock, num_classes=num_classes)

import os
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.models as models
import pretrainedmodels
import pandas as pd
class FCViewer(nn.Module):
    """Flatten every dimension after the batch dimension.

    Maps a tensor of shape ``(N, d1, d2, ...)`` to ``(N, d1 * d2 * ...)``.
    Here it turns the ``(N, C, 1, 1)`` output of ``AdaptiveAvgPool2d(1)``
    into an ``(N, C)`` feature matrix.
    """

    def forward(self, x):
        # ``reshape`` gives the same result as ``view`` for contiguous input.
        # It also accepts non-contiguous tensors (e.g. after a transpose or
        # permute), where ``view`` raises a RuntimeError.
        return x.reshape(x.size(0), -1)
class M(nn.Module):
    """Image feature extractor on a pretrainedmodels backbone or the local ResNet18.

    The backbone's last two child modules are removed. They are replaced by
    a global ``AdaptiveAvgPool2d(1)`` followed by a flatten, so ``forward``
    returns one feature vector per image.

    Args:
        backbone1: ``pretrainedmodels`` constructor name. It is only used
            when ``pretrained`` is truthy.
        drop: Accepted for backward compatibility; it has no effect. The
            original code branched on ``drop > 0`` but built the same head
            in both branches.
        pretrained: If truthy, use the ImageNet backbone ``backbone1``.
            Otherwise build ``ResNet18()`` and load ``weights_path`` into it.
        weights_path: Path to a state_dict checkpoint for the local ResNet18.
            It defaults to the previously hard-coded path.
    """

    def __init__(self, backbone1, drop, pretrained=True,
                 weights_path='/home/cc/Desktop/dj/model1/incption--7'):
        super().__init__()
        if pretrained:
            img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained='imagenet')
        else:
            img_model = ResNet18()
            # map_location='cpu': the checkpoint may have been saved from a
            # GPU model. Without map_location, torch.load restores tensors to
            # their original device, which fails on a CPU-only machine. The
            # values are copied into the CPU model's parameters either way.
            img_model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        # NOTE(review): [:-2] drops the last two child modules. For the local
        # ResNet those are ``layer4`` and ``fc``, because its average pooling
        # is functional and not a child module. The encoder therefore stops
        # after ``layer3`` and yields 256-d features, and the loaded layer4
        # weights go unused. This is kept unchanged to preserve the output
        # width; confirm the intent.
        encoder_layers = list(img_model.children())[:-2]
        encoder_layers.append(nn.AdaptiveAvgPool2d(1))
        self.img_encoder = nn.Sequential(*encoder_layers)
        # Flatten (N, C, 1, 1) -> (N, C). No dropout layer is added.
        self.img_fc = nn.Sequential(FCViewer())

    def forward(self, x_img):
        """Return one flattened feature vector per image in ``x_img``.

        Assumes ``x_img`` is a 4-D image batch ``(N, 3, H, W)``. TODO:
        confirm with callers.
        """
        features = self.img_encoder(x_img)
        return self.img_fc(features)
model1 = M('resnet18', 0, pretrained=None)
# Inference mode. BatchNorm must use its stored running statistics rather
# than the statistics of the current batch, which here is a single image.
# The running statistics must also not be updated while extracting features.
model1.eval()
features_dir = '/home/cc/Desktop/features'
transform1 = transforms.Compose([
    transforms.Resize(56),
    transforms.CenterCrop(32),
    # NOTE(review): there is no Normalize step. Preprocessing should match
    # whatever normalization (if any) was used when the checkpoint was
    # trained; confirm.
    transforms.ToTensor()])
file_path = '/home/cc/Desktop/picture'
# Sort for a reproducible row order, because os.listdir order is arbitrary.
# The CSV rows carry no file name, so this order is the only link between a
# row and its image.
names = sorted(os.listdir(file_path))
print(names)
out_csv = os.path.join(features_dir, '3.csv')
for name in names:
    pic = os.path.join(file_path, name)
    if not os.path.isfile(pic):
        # Skip sub-directories and other non-regular entries.
        continue
    with Image.open(pic) as img:
        # Force 3 channels. Grayscale, RGBA or palette images would
        # otherwise yield a tensor that the 3-channel first conv rejects.
        img1 = transform1(img.convert('RGB'))
    x = torch.unsqueeze(img1, dim=0).float()
    with torch.no_grad():
        y = model1(x)
    # tolist() turns the values into Python floats, so the CSV text is
    # formatted as before.
    y = y.numpy().tolist()
    test = pd.DataFrame(data=y)
    # Append mode: each image adds one row, and re-running the script keeps
    # appending to an existing 3.csv.
    test.to_csv(out_csv, mode='a+', index=False, header=False)

来源:https://blog.csdn.net/weixin_40123108/article/details/90678916

标签:Pytorch,特征,向量,csv
0
投稿

猜你喜欢

  • JSP学生信息管理系统设计

    2023-07-13 03:37:30
  • CentOS系统中PHP和MySQL的升级方法

    2023-11-20 21:04:19
  • 《JavaScript语言精粹》

    2009-04-03 11:27:00
  • python 利用百度API进行淘宝评论关键词提取

    2021-11-14 19:32:36
  • python计算时间差的方法

    2023-05-19 16:08:23
  • Python图像分割之均匀性度量法分析

    2021-02-11 11:45:24
  • pyramid配置session的方法教程

    2021-04-26 09:23:37
  • 有序列表 li ol

    2008-07-30 12:31:00
  • centos7利用yum安装lnmp的教程(linux+nginx+php7.1+mysql5.7)

    2023-11-14 11:40:18
  • 记录下两个正则表达式的使用

    2009-11-30 12:56:00
  • 详解go语言中sort如何排序

    2023-09-03 14:00:38
  • python实现ftp文件传输功能

    2023-04-21 13:20:16
  • Python基于csv模块实现读取与写入csv数据的方法

    2023-04-12 23:14:34
  • Oracle三种上载文件技术

    2010-07-16 13:34:00
  • 使用numpy实现矩阵的翻转(flip)与旋转

    2023-01-31 01:03:18
  • HTML中事件触发列表与解说

    2007-10-22 12:50:00
  • Asp定时执行操作、各种网页定时操作详解

    2008-06-10 17:32:00
  • HTTP中header头部信息详解

    2023-06-11 23:33:17
  • 发散后的期望

    2008-07-31 18:32:00
  • Python3.x+pyqtgraph实现数据可视化教程

    2023-09-25 23:24:47
  • asp之家 网络编程 m.aspxhome.com