Published: 2023-02-14 18:00
The dataset lives in the data folder: the training images are in "imgs" and the labels in "masks". The samples are medical cell-segmentation images, so the training set pairs each original image with its corresponding mask.
Sample training image (train_x):
Sample label (train_y):
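The training script below depends on BasicDataset from utils/dataset.py, which is not listed in this post. As a rough, hypothetical sketch of what it is assumed to provide (pairing each image in data/imgs/ with the same-named mask in data/masks/, downscaling, and normalizing to a channels-first array); the real class in the repository may differ:

# Hypothetical sketch of utils/dataset.py's BasicDataset -- illustration only.
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class BasicDataset(Dataset):
    def __init__(self, imgs_dir, masks_dir, scale=1.0):
        self.imgs_dir, self.masks_dir, self.scale = imgs_dir, masks_dir, scale
        # Pair every training image with the mask of the same file name.
        self.ids = [os.path.splitext(f)[0] for f in os.listdir(imgs_dir)]

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale):
        # Downscale, move channels first, and normalize pixels to [0, 1].
        w, h = pil_img.size
        pil_img = pil_img.resize((int(w * scale), int(h * scale)))
        img = np.array(pil_img)
        if img.ndim == 2:                      # grayscale -> add channel axis
            img = img[np.newaxis, ...]
        else:                                  # HWC -> CHW
            img = img.transpose((2, 0, 1))
        return img / 255.0 if img.max() > 1 else img

    def __getitem__(self, i):
        name = self.ids[i]
        # The .png extension is an assumption; adjust to your data.
        img = Image.open(os.path.join(self.imgs_dir, name + '.png'))
        mask = Image.open(os.path.join(self.masks_dir, name + '.png'))
        return {'image': self.preprocess(img, self.scale),
                'mask': self.preprocess(mask, self.scale)}

The full train.py: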
import argparse
import logging
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from eval import eval_net
from unet import UNet
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/imgs/'
dir_mask = 'data/masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net,
              device,
              epochs=2,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                # Validate roughly ten times per epoch; max(..., 1) guards against
                # a zero divisor when the training set is smaller than 10 batches.
                if global_step % max(n_train // (10 * batch_size), 1) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved!')

    writer.close()
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=100,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.00001,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    return parser.parse_args()
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
(1) Once the dataset paths (x, y) are set correctly, run python train.py to start training. Note: choose a batch_size that fits your machine's memory and hardware, and tune the epochs and learning rate to your particular dataset.
(2) Checkpoints are saved automatically to the "checkpoints" folder after every epoch; pick the best model, rename it to "MODEL.pth", and place it in the project root.
(3) The optimizer is RMSprop; the loss is BCEWithLogitsLoss for the single-class case (cross-entropy for multi-class), and validation is scored with the Dice coefficient.
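train.py also imports eval_net from eval.py, which is not listed here. Below is a minimal sketch, assuming it mirrors how train_net interprets the returned score (mean Dice coefficient for a single class, cross-entropy for several classes); the actual eval.py may differ:

# Hypothetical sketch of eval.py's eval_net -- illustration only.
import torch
import torch.nn.functional as F

def dice_coeff(pred, target, eps=1e-6):
    # Dice = 2*|A∩B| / (|A| + |B|), averaged over the batch.
    inter = (pred * target).sum(dim=(-2, -1))
    sums = pred.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return ((2 * inter + eps) / (sums + eps)).mean()

def eval_net(net, loader, device):
    net.eval()
    score, n_batches = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            mask_type = torch.float32 if net.n_classes == 1 else torch.long
            true_masks = batch['mask'].to(device=device, dtype=mask_type)
            preds = net(imgs)
            if net.n_classes > 1:
                # Multi-class: accumulate validation cross-entropy.
                score += F.cross_entropy(preds, true_masks).item()
            else:
                # Single class: threshold the sigmoid output, accumulate Dice.
                preds = (torch.sigmoid(preds) > 0.5).float()
                score += dice_coeff(preds, true_masks).item()
            n_batches += 1
    net.train()
    return score / max(n_batches, 1)

The full predict.py: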
import argparse
import logging
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

        # Resize the probability map back to the original image size.
        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(full_img.size[1]),
                transforms.ToTensor()
            ]
        )
        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold
def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='Filenames of input images', required=True)
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed',
                        default=False)
    parser.add_argument('--no-save', '-n', action='store_true',
                        help='Do not save the output masks',
                        default=False)
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help='Minimum probability value to consider a mask pixel white',
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help='Scale factor for the input images',
                        default=0.5)
    return parser.parse_args()
def get_output_filenames(args):
    in_files = args.input
    out_files = []
    if not args.output:
        # Default: save each mask next to its input as <name>_OUT.<ext>.
        for f in in_files:
            pathsplit = os.path.splitext(f)
            out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
    elif len(in_files) != len(args.output):
        logging.error("Input files and output files are not of the same length")
        raise SystemExit()
    else:
        out_files = args.output
    return out_files
def mask_to_image(mask):
    # Map the boolean mask to an 8-bit grayscale image (True -> 255).
    return Image.fromarray((mask * 255).astype(np.uint8))
if __name__ == \"__main__\":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=1, n_classes=1)
logging.info(\"Loading model {}\".format(args.model))
device = torch.device(\'cuda\' if torch.cuda.is_available() else \'cpu\')
logging.info(f\'Using device {device}\')
net.to(device=device)
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info(\"Model loaded !\")
for i, fn in enumerate(in_files):
logging.info(\"\\nPredicting image {} ...\".format(fn))
img = Image.open(fn)
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)
if not args.no_save:
out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
logging.info(\"Mask saved to {}\".format(out_files[i]))
if args.viz:
logging.info(\"Visualizing results for image {}, close to continue ...\".format(fn))
plot_img_and_mask(img, mask)
Notes:
(1) To run prediction: python predict.py -i <path to the input image> -o output.jpg (by default the output file is written to the project root; you can also change the output location).
(2) Prediction defaults to the model you renamed to "MODEL.pth" in the project root; keep tuning the hyperparameters to get better segmentation results.
(3) To build your own dataset, you can annotate images yourself with a labeling tool such as Labelme.
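Finally, predict.py imports plot_img_and_mask from utils/data_vis.py, also not listed in this post. A minimal matplotlib sketch of what that helper is assumed to do (the repository's version may differ):

# Hypothetical sketch of utils/data_vis.py's plot_img_and_mask -- illustration only.
import matplotlib.pyplot as plt

def plot_img_and_mask(img, mask):
    # Show the input image and the predicted mask side by side.
    fig, (ax_img, ax_mask) = plt.subplots(1, 2)
    ax_img.set_title('Input image')
    ax_img.imshow(img, cmap='gray')
    ax_mask.set_title('Predicted mask')
    ax_mask.imshow(mask, cmap='gray')
    for ax in (ax_img, ax_mask):
        ax.axis('off')
    plt.show()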