Code for hypernetworks

阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6


Code for hypernetworks

这篇文章将介绍怎么使用hypernetworks来完成一些实验,本实验基于https://github.com/g1910/HyperNetworks.git

主要的Class

PrimaryNetwork 是主要观察的类,主要观察 .forward 中如何生成参数部分。

class PrimaryNetwork(nn.Module):
    """ResNet-style CIFAR-10 classifier whose conv kernels are produced
    by a shared HyperNetwork instead of being trained directly.

    Each of the 18 residual blocks receives two generated weight tensors
    (one per conv layer); the per-layer latent codes live in the
    ``Embedding`` modules stored in ``self.zs``.
    """

    def __init__(self, z_dim=64):
        super(PrimaryNetwork, self).__init__()
        # Stem: a plain, directly-trained conv + batch norm.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        self.z_dim = z_dim
        # Single hypernetwork shared by every residual block.
        self.hope = HyperNetwork(z_dim=self.z_dim)

        # Grid of latent tiles per generated kernel: 2 entries per block
        # (36 total), growing as the channel count doubles.
        self.zs_size = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1],
                        [2, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2],
                        [4, 2], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4], [4, 4]]

        # (in_channels, out_channels) for each of the 18 residual blocks.
        self.filter_size = [[16, 16], [16, 16], [16, 16], [16, 16], [16, 16], [16, 16], [16, 32], [32, 32], [32, 32], [32, 32],
                            [32, 32], [32, 32], [32, 64], [64, 64], [64, 64], [64, 64], [64, 64], [64, 64]]

        # Downsample at the two channel-doubling boundaries (blocks 6 and 12).
        self.res_net = nn.ModuleList(
            ResNetBlock(in_ch, out_ch, downsample=(idx > 5 and idx % 6 == 0))
            for idx, (in_ch, out_ch) in enumerate(self.filter_size)
        )

        # One Embedding per generated kernel (two per residual block).
        self.zs = nn.ModuleList(
            Embedding(tile_size, self.z_dim) for tile_size in self.zs_size
        )

        self.global_avg = nn.AvgPool2d(8)
        self.final = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))

        # Core of the hypernetwork idea: for each residual block, ask the
        # embeddings to drive the shared hypernetwork and emit the two conv
        # weight tensors that the block then uses for its computation.
        for block_idx, block in enumerate(self.res_net):
            w1 = self.zs[2 * block_idx](self.hope)
            w2 = self.zs[2 * block_idx + 1](self.hope)
            x = block(x, w1, w2)

        x = self.global_avg(x)
        return self.final(x.view(-1, 64))

同样重要的,还有 HyperNetwork:

class HyperNetwork(nn.Module):
    """Two-layer hypernetwork that maps a latent code ``z`` to a conv kernel.

    Given ``z`` of shape ``(z_dim,)``, ``forward`` returns a weight tensor of
    shape ``(out_size, in_size, f_size, f_size)`` suitable for use as a
    ``Conv2d`` kernel in the primary network.

    Args:
        f_size: spatial size of the generated square filters.
        z_dim: dimensionality of the latent code.
        out_size: number of output channels of the generated kernel.
        in_size: number of input channels of the generated kernel.
    """

    def __init__(self, f_size=3, z_dim=64, out_size=16, in_size=16):
        super(HyperNetwork, self).__init__()
        self.z_dim = z_dim
        self.f_size = f_size
        self.out_size = out_size
        self.in_size = in_size

        # fmod(., 2) clamps the random init into (-2, 2).
        # FIX: the original hard-coded .cuda() on every Parameter, which
        # crashes on CPU-only machines and is redundant — the training
        # script moves the whole model with net.cuda(). Parameters are now
        # created on the default device, keeping the module device-agnostic.
        self.w1 = Parameter(torch.fmod(torch.randn(self.z_dim, self.out_size * self.f_size * self.f_size), 2))
        self.b1 = Parameter(torch.fmod(torch.randn(self.out_size * self.f_size * self.f_size), 2))

        self.w2 = Parameter(torch.fmod(torch.randn(self.z_dim, self.in_size * self.z_dim), 2))
        self.b2 = Parameter(torch.fmod(torch.randn(self.in_size * self.z_dim), 2))

    def forward(self, z):
        """Generate a conv kernel from latent code ``z`` (shape ``(z_dim,)``)."""
        # First affine layer: expand z into one z_dim-vector per input channel.
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(self.in_size, self.z_dim)

        # Second affine layer: each per-channel vector becomes a full
        # (out_size * f_size * f_size) slice of the kernel.
        h_final = torch.matmul(h_in, self.w1) + self.b1
        kernel = h_final.view(self.out_size, self.in_size, self.f_size, self.f_size)

        return kernel

训练的过程就很一致了,不再赘述:

"""CIFAR-10 training script for the hypernetwork-generated ResNet."""
import torch
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable
import torch.nn as nn

import argparse

import torch.optim as optim

from primary_net import PrimaryNetwork

########### Data Loader ###############

# Standard CIFAR-10 augmentation + per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#############################

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()


############

net = PrimaryNetwork()
best_accuracy = 0.

if args.resume:
    ckpt = torch.load('./hypernetworks_cifar_paper.pth')
    net.load_state_dict(ckpt['net'])
    best_accuracy = ckpt['acc']

net.cuda()

learning_rate = 0.002
weight_decay = 0.0005
# Milestones are in *iterations*, matching the per-iteration scheduler.step().
milestones = [168000, 336000, 400000, 450000, 550000, 600000]
max_iter = 1000000

optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
criterion = nn.CrossEntropyLoss()

total_iter = 0
epochs = 0
print_freq = 50
while total_iter < max_iter:

    running_loss = 0.0
    net.train()  # FIX: ensure BatchNorm/dropout are back in training mode after eval

    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()
        lr_scheduler.step()  # step per iteration (milestones are iteration counts)

        # FIX: loss.data[0] raises IndexError on 0-dim tensors (PyTorch >= 0.5);
        # loss.item() is the supported way to read a scalar loss.
        running_loss += loss.item()
        if i % print_freq == (print_freq - 1):
            print("[Epoch %d, Total Iterations %6d] Loss: %.4f" % (epochs + 1, total_iter + 1, running_loss / print_freq))
            running_loss = 0.0

        total_iter += 1

    epochs += 1

    # ---- Evaluation at the end of each epoch ----
    correct = 0.
    total = 0.
    # FIX: evaluate in eval mode (BatchNorm uses running stats, not batch
    # stats) and without building autograd graphs.
    net.eval()
    with torch.no_grad():
        for tdata in testloader:
            timages, tlabels = tdata
            toutputs = net(timages.cuda())
            _, predicted = torch.max(toutputs.cpu().data, 1)
            total += tlabels.size(0)
            correct += (predicted == tlabels).sum().item()

    accuracy = (100. * correct) / total
    print('After epoch %d, accuracy: %.4f %%' % (epochs, accuracy))

    if accuracy > best_accuracy:
        print('Saving model...')
        state = {
            'net': net.state_dict(),
            'acc': accuracy
        }
        torch.save(state, './hypernetworks_cifar_paper.pth')
        best_accuracy = accuracy

print('Finished Training')


阿里云国内75折 回扣 微信号:monov8
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6