U-net图像分割网络的用法(pytorch-python)

发布时间:2023-02-14 18:00

1.介绍数据集

数据集放在data文件夹里:训练图像在“imgs”文件夹中,对应的标签(label,即掩码图)在“masks”文件夹中。数据集用的是医学影像细胞分割的样本,训练集中每张原图都有一张对应的标签图。
train_x的样图:
"U-net图像分割网络的用法(pytorch-python)_第1张图片"
train_y(label)样图:
"U-net图像分割网络的用法(pytorch-python)_第2张图片"

2.训练模型

import argparse
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

from eval import eval_net
from unet import UNet

from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split

# Default data locations, relative to the current working directory.
dir_img: str = 'data/imgs/'  # training input images (passed to BasicDataset)
dir_mask: str = 'data/masks/'  # label masks, read together with dir_img by BasicDataset
dir_checkpoint: str = 'checkpoints/'  # per-epoch checkpoints written by train_net


def train_net(net,
              device,
              epochs=2,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
    """Train ``net`` on the image/mask pairs found in ``dir_img`` and ``dir_mask``.

    The dataset is split randomly into a training and a validation subset.
    Roughly ten times per epoch the network is scored on the validation
    subset with ``eval_net``.  That score drives a ``ReduceLROnPlateau``
    scheduler and is written to TensorBoard, together with weight/gradient
    histograms and the current batch of images (and, for single-class nets,
    the true and predicted masks).

    Args:
        net: model to train; must expose ``n_channels`` and ``n_classes``.
        device: ``torch.device`` used for training.
        epochs: number of passes over the training subset.
        batch_size: number of samples per batch.
        lr: initial learning rate of the RMSprop optimizer.
        val_percent: fraction (0-1) of the dataset held out for validation.
        save_cp: if true, save ``net.state_dict()`` into ``dir_checkpoint``
            after every epoch.
        img_scale: scale factor passed to ``BasicDataset``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

    # Validate/log roughly 10 times per epoch.  max(1, ...) keeps the modulo
    # below from raising ZeroDivisionError when the training subset holds
    # fewer than 10 * batch_size samples.
    log_interval = max(1, n_train // (10 * batch_size))
    # With drop_last=True the validation loader yields no batch at all when
    # the validation subset is smaller than batch_size.
    has_val = len(val_loader) > 0

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')
    if not has_val:
        logging.warning('Validation loader is empty; validation and LR scheduling are skipped')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # The validation score is a cross-entropy loss (minimise) for multi-class
    # nets and a Dice coefficient (maximise) for single-class nets.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()
    # BCEWithLogitsLoss needs float targets; CrossEntropyLoss needs class indices.
    mask_type = torch.float32 if net.n_classes == 1 else torch.long

    try:
        for epoch in range(epochs):
            net.train()

            epoch_loss = 0.0
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == net.n_channels, \
                        f'Network has been defined with {net.n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    # .item() synchronises with the device; call it once per step.
                    batch_loss = loss.item()
                    epoch_loss += batch_loss
                    writer.add_scalar('Loss/train', batch_loss, global_step)

                    pbar.set_postfix(**{'loss (batch)': batch_loss})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1
                    if global_step % log_interval != 0:
                        continue

                    _log_param_histograms(writer, net, global_step)
                    if has_val:
                        val_score = eval_net(net, val_loader, device)
                        # NOTE(review): eval_net presumably puts the net in eval
                        # mode; restore training mode explicitly (a no-op if
                        # eval_net already does so).
                        net.train()
                        scheduler.step(val_score)

                        if net.n_classes > 1:
                            logging.info('Validation cross entropy: {}'.format(val_score))
                            writer.add_scalar('Loss/test', val_score, global_step)
                        else:
                            logging.info('Validation Dice Coeff: {}'.format(val_score))
                            writer.add_scalar('Dice/test', val_score, global_step)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

            n_batches = len(train_loader)
            if n_batches:
                logging.info(f'Epoch {epoch + 1} mean training loss: {epoch_loss / n_batches:.6f}')

            if save_cp:
                _save_checkpoint(net, epoch)
    finally:
        # Close the writer even if training is interrupted (e.g. KeyboardInterrupt).
        writer.close()


def _log_param_histograms(writer, net, step):
    """Write weight and gradient histograms of every parameter of ``net`` to ``writer``."""
    for name, param in net.named_parameters():
        tag = name.replace('.', '/')
        writer.add_histogram('weights/' + tag, param.data.cpu().numpy(), step)
        # Parameters that took no part in the forward pass have no gradient.
        if param.grad is not None:
            writer.add_histogram('grads/' + tag, param.grad.data.cpu().numpy(), step)


def _save_checkpoint(net, epoch):
    """Save ``net.state_dict()`` as ``CP_epoch{epoch + 1}.pth`` inside ``dir_checkpoint``."""
    if not os.path.isdir(dir_checkpoint):
        # makedirs also creates missing parents; real errors (e.g. permissions)
        # now surface instead of being swallowed.
        os.makedirs(dir_checkpoint, exist_ok=True)
        logging.info('Created checkpoint directory')
    torch.save(net.state_dict(),
               os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth'))
    logging.info(f'Checkpoint {epoch + 1} saved !')


def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: list of argument strings to parse; ``None`` (the default)
            parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``epochs``, ``batchsize``, ``lr``,
        ``load``, ``scale`` and ``val`` (validation percent, 0-100).

    Raises:
        SystemExit: (via ``parser.error``) on invalid or out-of-range values.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=100,
                        help='Number of epochs', dest='epochs')
    # nargs='?' lets the flag appear without a value; const makes that case
    # fall back to the default instead of producing None.
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1, const=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?',
                        default=0.00001, const=0.00001,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    args = parser.parse_args(argv)
    if args.batchsize < 1:
        parser.error('batch size must be at least 1')
    if args.scale <= 0:
        parser.error('scale must be positive')
    # 100 % validation would leave nothing to train on.
    if not 0 <= args.val < 100:
        parser.error('validation percent must be in the range [0, 100)')
    return args


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    logging.info(f'Using device {device}')

    # Adapt the network to your data:
    #   n_channels -- channels per input image (1 = grayscale, 3 = RGB).
    #   n_classes  -- output channels per pixel.  Use 1 for binary
    #                 (foreground vs. background) segmentation, which
    #                 train_net trains with BCEWithLogitsLoss; use N for
    #                 N > 2 classes, trained with CrossEntropyLoss.
    net = UNet(n_channels=1, n_classes=1, bilinear=True)
    upscaling = 'Bilinear' if net.bilinear else 'Transposed conv'
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{upscaling} upscaling')

    if args.load:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(args.load, map_location=device)
        net.load_state_dict(state)
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # For faster convolutions at the cost of more memory, enable:
    # torch.backends.cudnn.benchmark = True

    try:
        train_net(net=net,
                  device=device,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  val_percent=args.val / 100,
                  img_scale=args.scale)
    except KeyboardInterrupt:
        # Keep the weights trained so far; they can be reloaded with --load.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        # Hard exit with status 0 without running interpreter shutdown
        # (presumably so lingering DataLoader workers cannot block the exit).
        os._exit(0)

(1)将训练集的路径(图像 x 和标签 y,即代码中的 dir_img 和 dir_mask)改好以后,在命令行输入 python train.py 即可开始训练。注意:可根据电脑的内存和配置情况选择 batch_size 的大小(-b 参数);epoch 数(-e)和学习率(-l)也要根据数据集的不同自己调参。
(2)每个 epoch 结束后,模型会自动保存在“checkpoints”文件夹下(文件名为 CP_epoch{N}.pth)。从中选择效果最好的模型,改名为“MODEL.pth”放在根目录下。
(3)优化器选择“RMSprop”。损失函数和评估指标取决于类别数:n_classes=1(本例)时,损失函数为 BCEWithLogitsLoss,验证指标为 Dice 系数;n_classes>1 时,损失函数和验证指标均为交叉熵。

3.导入模型,随后预测

import argparse
import logging
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset


def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Predict a boolean segmentation mask for a single PIL image.

    The image is preprocessed with ``BasicDataset.preprocess`` at
    ``scale_factor`` and run through ``net`` in eval mode without gradients.
    The output is turned into probabilities (sigmoid for a single output
    channel, softmax over channels otherwise), resized back to the size of
    ``full_img`` and thresholded.

    Args:
        net: segmentation model exposing ``n_classes``.
        full_img: input PIL image; ``full_img.size`` is ``(W, H)``.
        device: ``torch.device`` to run inference on.
        scale_factor: scale passed to ``BasicDataset.preprocess``.
        out_threshold: probability above which a pixel counts as foreground.

    Returns:
        A numpy bool array of shape ``(H, W)`` when the net has one output
        channel, otherwise ``(C, H, W)``.
    """
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        # Assumes a (1, C, h, w) output, as implied by softmax over dim=1.
        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Resize on the tensor itself, to exactly the input's (H, W) (PIL
        # reports size as (W, H)).  This keeps full float precision and works
        # for any number of channels; resizing with an int size would only
        # match the shorter edge.
        full_h, full_w = full_img.size[1], full_img.size[0]
        probs = F.interpolate(probs, size=(full_h, full_w), mode='bilinear', align_corners=False)

        full_mask = probs[0]
        if full_mask.shape[0] == 1:
            # Single-channel output: drop the channel axis -> (H, W).
            full_mask = full_mask[0]
        full_mask = full_mask.cpu().numpy()

    return full_mask > out_threshold


def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: list of argument strings to parse; ``None`` (the default)
            parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``model``, ``input``, ``output``, ``viz``,
        ``no_save``, ``mask_threshold`` and ``scale``.

    Raises:
        SystemExit: (via ``parser.error``) on missing input or out-of-range values.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which the model is stored")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='filenames of input images', required=True)
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed")
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks")
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help="Minimum probability value to consider a mask pixel white",
                        default=0.5)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images",
                        default=0.5)

    args = parser.parse_args(argv)
    if args.scale <= 0:
        parser.error('scale must be positive')
    # The threshold is compared against probabilities in [0, 1].
    if not 0 <= args.mask_threshold <= 1:
        parser.error('mask threshold must be in the range [0, 1]')
    return args


def get_output_filenames(args):
    """Return one output filename for each input filename in ``args.input``.

    If ``args.output`` is empty or None, each name is derived from its input
    as ``<root>_OUT<ext>`` (same directory, same extension).  Otherwise
    ``args.output`` is returned as given and must match ``args.input`` in length.

    Raises:
        SystemExit: with exit status 1 if the two lists differ in length.
    """
    in_files = args.input

    if not args.output:
        return ["{}_OUT{}".format(*os.path.splitext(f)) for f in in_files]

    if len(in_files) != len(args.output):
        logging.error("Input files and output files are not of the same length")
        # A bare SystemExit() exits with status 0 and would report success
        # to the calling shell; signal the error explicitly.
        raise SystemExit(1)

    return args.output


def mask_to_image(mask):
    """Convert a predicted mask array into an 8-bit grayscale PIL image.

    A 2-D mask (bool, or values in [0, 1]) is scaled by 255, so foreground
    pixels become white.  A 3-D mask is treated as ``(C, H, W)``:
    - a single channel is handled like the 2-D case;
    - with ``C > 1``, the per-pixel argmax over channels is taken and the
      class indices are spread evenly over 0..255.
    """
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]
    if mask.ndim == 3:
        labels = np.argmax(mask, axis=0)
        step = 255 // (mask.shape[0] - 1)
        return Image.fromarray((labels * step).astype(np.uint8))
    return Image.fromarray((mask * 255).astype(np.uint8))


if __name__ == "__main__":
    # Without this, the logging.info progress messages below are dropped
    # (the root logger only emits WARNING and above by default).
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    # Must match the architecture the checkpoint was trained with.
    # NOTE(review): training uses bilinear=True; this relies on UNet's default
    # for `bilinear` being the same -- confirm.
    net = UNet(n_channels=1, n_classes=1)

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    # NOTE(review): torch.load unpickles the file; only load trusted models.
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    # get_output_filenames guarantees one output name per input name.
    for fn, out_fn in zip(in_files, out_files):
        logging.info("\nPredicting image {} ...".format(fn))

        # The context manager releases each file handle once the image is used.
        with Image.open(fn) as img:
            mask = predict_img(net=net,
                               full_img=img,
                               scale_factor=args.scale,
                               out_threshold=args.mask_threshold,
                               device=device)

            if not args.no_save:
                result = mask_to_image(mask)
                result.save(out_fn)

                logging.info("Mask saved to {}".format(out_fn))

            if args.viz:
                logging.info("Visualizing results for image {}, close to continue ...".format(fn))
                plot_img_and_mask(img, mask)

注意:
(1)运行预测时,可以输入 python predict.py -i (这里输入预测图片的路径) -o output.jpg(这里默认输出文件保存在当前目录下,也可以改变输出文件的位置;如果省略 -o,输出会保存在输入图片所在目录下,文件名为“原文件名_OUT.原扩展名”)。建议输出为 .png 格式,因为 JPEG 的有损压缩会让二值掩码的边缘出现灰度值。
(2)预测时默认加载放在根目录下、改名为“MODEL.pth”的模型(也可以用 -m 参数指定其他模型文件)。多尝试调参,可以获得更好的分割效果。
(3)如果想制作自己的数据集,也可以用Labelme等打标工具,自己制作。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号