Start with the simplest possible model. The input image is 3x224x224, the encoder follows the VGG11 layout, and after the fifth max-pooling layer the 7x7 feature map is upsampled through five transposed-convolution layers back to the original image size.
model.py:
import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encode1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.encode5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        # Each decode block doubles the spatial resolution with a stride-2 transposed convolution.
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True)
        )
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.decode4 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )
        self.decode5 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 2, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )
        self.classifier = nn.Conv2d(16, 2, kernel_size=1)

    def forward(self, x):              # b: batch_size
        out = self.encode1(x)          # [b, 3, 224, 224]  => [b, 64, 112, 112]
        out = self.encode2(out)        # [b, 64, 112, 112] => [b, 128, 56, 56]
        out = self.encode3(out)        # [b, 128, 56, 56]  => [b, 256, 28, 28]
        out = self.encode4(out)        # [b, 256, 28, 28]  => [b, 512, 14, 14]
        out = self.encode5(out)        # [b, 512, 14, 14]  => [b, 512, 7, 7]
        out = self.decode1(out)        # [b, 512, 7, 7]    => [b, 256, 14, 14]
        out = self.decode2(out)        # [b, 256, 14, 14]  => [b, 128, 28, 28]
        out = self.decode3(out)        # [b, 128, 28, 28]  => [b, 64, 56, 56]
        out = self.decode4(out)        # [b, 64, 56, 56]   => [b, 32, 112, 112]
        out = self.decode5(out)        # [b, 32, 112, 112] => [b, 16, 224, 224]
        out = self.classifier(out)     # [b, 16, 224, 224] => [b, 2, 224, 224], 2 = number of classes (target / non-target)
        return out


if __name__ == '__main__':
    img = torch.randn(2, 3, 224, 224)
    net = Net()
    sample = net(img)
    print(sample.shape)
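Why output_padding=1: for dilation 1, PyTorch computes the transposed-convolution output size as H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding, so with kernel_size=3, padding=1, output_padding=1 and stride=2 every decode block exactly doubles the spatial size (7 -> 14 -> 28 -> 56 -> 112 -> 224). A quick sanity-check sketch of one such layer:

# Sketch: one stride-2 ConvTranspose2d doubles the spatial size.
# H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding = (7 - 1) * 2 - 2 + 3 + 1 = 14
import torch
from torch import nn

deconv = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
x = torch.randn(1, 512, 7, 7)
print(deconv(x).shape)    # torch.Size([1, 256, 14, 14])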
The data is laid out as shown below: the images go in last and the corresponding label masks go in last_msk (a small pairing check is sketched after the tree).
├─data
├─test
│ ├─last
│ └─last_msk
└─train
├─last
└─last_msk
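MyDataset below simply pairs the i-th entry of last with the i-th entry of last_msk, so image and mask filenames have to correspond. A small check along these lines (a sketch; it assumes an image and its mask share the same filename) catches mismatches before training:

# Sketch: verify every image in data/train/last has a mask of the same name in data/train/last_msk.
import os

train_path = 'data/train'
images = sorted(os.listdir(train_path + '/last'))
masks = set(os.listdir(train_path + '/last_msk'))
missing = [name for name in images if name not in masks]
print('images: %d, masks: %d, images without a mask: %d' % (len(images), len(masks), len(missing)))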
load_img.py:
from torch.utils.data import Dataset
import os
import cv2
import numpy as np


class MyDataset(Dataset):
    def __init__(self, train_path, transform=None):
        self.images = os.listdir(train_path + '/last')
        self.labels = os.listdir(train_path + '/last_msk')
        assert len(self.images) == len(self.labels), 'Number does not match'
        self.transform = transform
        self.images_and_labels = []    # list of (image path, label path) pairs
        for i in range(len(self.images)):
            self.images_and_labels.append((train_path + '/last/' + self.images[i],
                                           train_path + '/last_msk/' + self.labels[i]))

    def __getitem__(self, item):
        img_path, lab_path = self.images_and_labels[item]
        img = cv2.imread(img_path)
        img = cv2.resize(img, (224, 224))
        lab = cv2.imread(lab_path, 0)
        lab = cv2.resize(lab, (224, 224))
        lab = lab / 255                  # scale to [0, 1]
        lab = lab.astype('uint8')        # everything that is not pure white (255) becomes 0, white becomes 1
        lab = np.eye(2)[lab]             # one-hot encode: [224, 224] => [224, 224, 2]
        lab = np.array(list(map(lambda x: abs(x - 1), lab))).astype('float32')
        # flip the one-hot channels: channel 0 marks the white background (255), channel 1 marks the black target
        lab = lab.transpose(2, 0, 1)     # [224, 224, 2] => [2, 224, 224]
        if self.transform is not None:
            img = self.transform(img)
        return img, lab

    def __len__(self):
        return len(self.images)


if __name__ == '__main__':
    img = cv2.imread('data/train/last_msk/150.jpg', 0)
    img = cv2.resize(img, (16, 16))
    img2 = img / 255
    img3 = img2.astype('uint8')
    hot1 = np.eye(2)[img3]
    hot2 = np.array(list(map(lambda x: abs(x - 1), hot1)))
    print(hot2.shape)
    print(hot2.transpose(2, 0, 1))
Running load_img.py directly prints the encoded matrix of one label image (here the mask 150.jpg, downscaled to 16x16 so it fits on screen).
150.jpg:
(16, 16, 2)
[[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
[1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
[1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
[1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
[1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]
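The "one-hot then flip" trick used above is equivalent to one-hot encoding the inverted label, i.e. np.eye(2)[1 - lab]. A tiny sketch with a made-up 2x3 mask shows that both paths give the same result:

# Sketch: the eye-then-flip encoding equals one-hot encoding of the inverted label.
import numpy as np

lab = np.array([[1, 1, 0],
                [1, 0, 0]], dtype='uint8')    # toy mask: 1 = white background, 0 = black target
flipped = np.abs(np.eye(2)[lab] - 1)          # encoding used in MyDataset
direct = np.eye(2)[1 - lab]                   # equivalent one-liner
print(np.array_equal(flipped, direct))        # True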
train.py:
import os
import model
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from load_img import MyDataset
from torchvision import transforms
from torch.utils.data import DataLoader

batchsize = 8
epochs = 50
train_data_path = 'data/train'

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])
bag = MyDataset(train_data_path, transform)
dataloader = DataLoader(bag, batch_size=batchsize, shuffle=True)

device = torch.device('cuda')
net = model.Net().to(device)
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.7)

if not os.path.exists('checkpoints'):
    os.mkdir('checkpoints')

for epoch in range(1, epochs + 1):
    for batch_idx, (img, lab) in enumerate(dataloader):
        img, lab = img.to(device), lab.to(device)
        output = torch.sigmoid(net(img))    # BCELoss expects probabilities, so squash the logits first
        loss = criterion(output, lab)

        output_np = output.cpu().data.numpy().copy()
        output_np = np.argmin(output_np, axis=1)    # per-pixel predicted class map
        y_np = lab.cpu().data.numpy().copy()
        y_np = np.argmin(y_np, axis=1)              # per-pixel ground-truth class map

        if batch_idx % 20 == 0:
            print('Epoch:[{}/{}]\tStep:[{}/{}]\tLoss:{:.6f}'.format(
                epoch, epochs, (batch_idx + 1) * len(img), len(dataloader.dataset), loss.item()
            ))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % 10 == 0:
        torch.save(net, 'checkpoints/model_epoch_{}.pth'.format(epoch))
        print('checkpoints/model_epoch_{}.pth saved!'.format(epoch))
Epoch:[1/50] Step:[8/499] Loss:0.702611
Epoch:[1/50] Step:[168/499] Loss:0.697093
Epoch:[1/50] Step:[328/499] Loss:0.686626
Epoch:[1/50] Step:[488/499] Loss:0.676049
Epoch:[2/50] Step:[8/499] Loss:0.667989
Epoch:[2/50] Step:[168/499] Loss:0.664439
Epoch:[2/50] Step:[328/499] Loss:0.638619
Epoch:[2/50] Step:[488/499] Loss:0.636599
Epoch:[3/50] Step:[8/499] Loss:0.616667
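A side note on the loss: train.py applies sigmoid manually and then uses nn.BCELoss, which works, but PyTorch's nn.BCEWithLogitsLoss fuses the two and is numerically more stable. A minimal sketch of the change (net, img and lab as defined in train.py):

import torch.nn as nn

# Fused sigmoid + binary cross-entropy; replaces criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()

# Inside the batch loop: pass raw logits, no torch.sigmoid() on the output
output = net(img)
loss = criterion(output, lab)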
The inference script below loads the trained checkpoint, runs it over the test images, and writes the predicted masks to result/:
import torch
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import os


class TestDataset(Dataset):
    def __init__(self, test_img_path, transform=None):
        self.test_img = os.listdir(test_img_path)
        self.transform = transform
        self.images = []
        for i in range(len(self.test_img)):
            self.images.append(os.path.join(test_img_path, self.test_img[i]))

    def __getitem__(self, item):
        img_path = self.images[item]
        img = cv2.imread(img_path)
        img = cv2.resize(img, (224, 224))
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.test_img)


test_img_path = 'data/test/last'
checkpoint_path = 'checkpoints/model_epoch_50.pth'
save_dir = 'result'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])
bag = TestDataset(test_img_path, transform)
dataloader = DataLoader(bag, batch_size=1, shuffle=False)

net = torch.load(checkpoint_path)
net = net.cuda()

for idx, img in enumerate(dataloader):
    img = img.cuda()
    output = torch.sigmoid(net(img))
    output_np = output.cpu().data.numpy().copy()
    output_np = np.argmin(output_np, axis=1)    # per-pixel class index: 0 = target, 1 = background
    img_arr = np.squeeze(output_np)
    img_arr = img_arr * 255                     # background -> 255 (white), target -> 0 (black), matching the masks
    cv2.imwrite('%s/%03d.png' % (save_dir, idx), img_arr)
    print('%s/%03d.png' % (save_dir, idx))
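Since the test set also ships ground-truth masks in data/test/last_msk, a quick quantitative check is possible. The sketch below computes pixel accuracy and target-class IoU for one prediction / ground-truth pair; the file names are placeholders, so pair them according to however your test images were ordered when the predictions were written:

# Sketch: pixel accuracy and target IoU for one predicted mask against its ground truth.
import cv2
import numpy as np

pred = cv2.imread('result/000.png', 0)               # predicted mask: 0 = target, 255 = background
gt = cv2.imread('data/test/last_msk/150.jpg', 0)     # ground-truth mask (placeholder file name)
gt = cv2.resize(gt, (224, 224))

pred_target = pred < 128                             # boolean target masks
gt_target = gt < 128

pixel_acc = (pred_target == gt_target).mean()
intersection = np.logical_and(pred_target, gt_target).sum()
union = np.logical_or(pred_target, gt_target).sum()
iou = intersection / union if union > 0 else 1.0
print('pixel accuracy: %.4f, target IoU: %.4f' % (pixel_acc, iou))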