Reading notes on "Multi-level Wavelet-CNN for Image Restoration", with a PyTorch reimplementation

Published: 2024-06-03 09:01

This post contains my reading notes on MWCNN. Paper: https://arxiv.org/pdf/1805.07071.pdf

Code: https://github.com/lpj0/MWCNN (MATLAB only)

Using that code as a reference, I reimplement the network under the PyTorch framework.

 

The network architecture is shown in the figure below.

[Figure: MWCNN network architecture]

Residual blocks are incorporated in each level of the encoder and decoder.

At each level, the discrete wavelet transform (DWT) is adopted as the downsampling operator in the encoder, and the inverse wavelet transform (IWT) as the upsampling operator in the decoder.

3×3 convolutions are deployed to compress and expand the feature maps after DWT and IWT, respectively.

For each level, two residual blocks are further deployed to enhance feature representation and reconstruction.

 

Reading the paper

Image restoration, which aims to recover the latent clean image x from its degraded observation y, is a fundamental and long-standing problem in low level vision.

For image restoration, CNN actually represents a mapping from degraded observation to latent clean image.

One representative strategy is to use a fully convolutional network (FCN) by removing the pooling layers. In general, a larger receptive field helps restoration performance by taking more spatial context into account. However, for an FCN without pooling, the receptive field can only be enlarged by increasing the network depth or using larger filters, which inevitably results in higher computational cost.

Dilated filtering can enlarge the receptive field without sacrificing computational cost. However, it inherently suffers from the gridding effect, where the receptive field only considers a sparse sampling of the input image with checkerboard patterns.


For low-level restoration, receptive field size and efficiency trade off against each other: a plain CNN usually enlarges its receptive field at the cost of computation. The multi-level wavelet CNN (MWCNN) model proposed in this paper aims at a better trade-off between receptive field size and computational efficiency. With a modified U-Net architecture, the wavelet transform is introduced to reduce the size of the feature maps in the contracting subnetwork, and a convolution layer then further reduces the number of feature-map channels. In the expanding subnetwork, the inverse wavelet transform is deployed to reconstruct high-resolution feature maps. Moreover, through its connection to dilated filtering and subsampling, the network can be applied to other image restoration tasks.

MWCNN enlarges the receptive field for a better trade-off between performance and efficiency. It builds on the U-Net architecture, consisting of a contracting subnetwork and an expanding subnetwork.

In the contracting subnetwork, the discrete wavelet transform (DWT) replaces each pooling operation. Since DWT is invertible, no information is lost by this downsampling scheme. Moreover, DWT captures both the frequency and the location information of feature maps, which can help with recovering detail.
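As a quick check of this losslessness, here is a minimal PyWavelets sketch (`img` is just random data standing in for a grayscale image): one level of Haar DWT yields four half-resolution subbands, and the inverse transform reconstructs the input exactly.

import numpy as np
import pywt

img = np.random.rand(64, 64)  # stand-in for a grayscale image

# 2D Haar DWT: one low-frequency band (LL) and three high-frequency bands (LH, HL, HH)
LL, (LH, HL, HH) = pywt.dwt2(img, 'haar')
print(LL.shape)  # (32, 32) -- each subband is half resolution

# The inverse transform recovers the input up to floating-point error
rec = pywt.idwt2((LL, (LH, HL, HH)), 'haar')
print(np.allclose(img, rec))  # True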

In the expanding subnetwork, the inverse wavelet transform (IWT) is utilized to upsample low-resolution feature maps to high-resolution ones.

To enrich the feature representation and reduce the computational burden, element-wise summation is adopted for combining the feature maps from the contracting and expanding subnetworks.

 

Multi-level wavelet packet transform (WPT)

[Figure: multi-level WPT decomposition and reconstruction of a single image]

The figure above shows WPT decomposition and reconstruction of a single image. In fact, WPT can be seen as a special case of an FCN without the nonlinearity layers. MWCNN builds on WPT by adding convolution layers: a CNN block is inserted between every two levels of DWTs, as shown below.

[Figure: MWCNN as WPT with CNN blocks inserted between transform levels]

After each level of transform, all the subband images are taken as the inputs to a CNN block to learn a compact representation, which in turn serves as the input to the subsequent transform level. MWCNN is a generalization of multi-level WPT, and degrades to WPT when each CNN block becomes the identity mapping.

MWCNN can use subsampling operations safely without information loss. Moreover, compared with a conventional CNN, the frequency and location characteristics of DWT are also expected to benefit the preservation of detailed texture.

 

Network architecture of the MWCNN

[Figure: network architecture of the MWCNN]

The key to the MWCNN architecture is the design of the CNN block after each level of DWT.

Each CNN block consists of four fully convolutional layers (no pooling) and takes all subband images as input. In contrast, deep convolutional framelets deploy different CNNs to the low-frequency and high-frequency bands, even though the subband images after DWT are still dependent on each other.

Each layer of the CNN block is composed of convolution with 3×3 filters (Conv), batch normalization (BN), and rectified linear unit (ReLU) operations. As for the last layer of the last CNN block, a Conv without BN and ReLU is adopted to predict the residual image.

Differences between MWCNN and U-Net

  1. Downsampling and upsampling: U-Net uses max pooling and up-convolution, whereas MWCNN uses DWT and IWT.
  2. In MWCNN, downsampling increases the number of feature-map channels. In U-Net, downsampling does not affect the channel count; subsequent convolution layers are used to increase it.
  3. In MWCNN, element-wise summation combines the feature maps from the contracting and expanding subnetworks, whereas conventional U-Net uses concatenation (see the snippet below).
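A one-liner makes the third difference concrete (shapes chosen arbitrarily for illustration):

import torch

a = torch.randn(1, 64, 32, 32)  # decoder feature map
b = torch.randn(1, 64, 32, 32)  # skip connection from the encoder

print(torch.cat([a, b], dim=1).shape)  # U-Net style: torch.Size([1, 128, 32, 32])
print((a + b).shape)                   # MWCNN style: torch.Size([1, 64, 32, 32])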

 

Implementing DWT in PyTorch

https://github.com/t-vi/pytorch-tvmisc/blob/master/misc/2D-Wavelet-Transform.ipynb

https://github.com/fbcotter/pytorch_wavelets

https://pytorch-wavelets.readthedocs.io/en/latest/dwt.html

cd /home/guanwp/BasicSR-master/codes/models/modules/
git clone https://github.com/fbcotter/pytorch_wavelets
cd pytorch_wavelets
pip install .
 

Alternatively, use PyWavelets directly:

https://blog.csdn.net/nanbei2463776506/article/details/64124841

The paper uses the Haar wavelet.
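A quick sanity check of the pytorch_wavelets API used in the reimplementation below: one level of Haar DWT followed by IWT should reproduce the input.

import torch
from pytorch_wavelets import DWTForward, DWTInverse

dwt = DWTForward(J=1, wave='haar')   # one decomposition level
idwt = DWTInverse(wave='haar')

x = torch.randn(1, 3, 64, 64)
yl, yh = dwt(x)                      # yl: (1, 3, 32, 32); yh[0]: (1, 3, 3, 32, 32)
print(yl.shape, yh[0].shape)

rec = idwt((yl, yh))                 # round trip
print(torch.allclose(x, rec, atol=1e-5))  # True, up to numerical error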

 

Code reimplementation

Dataset download links first:

Waterloo Exploration Database (WED) https://ece.uwaterloo.ca/~k29ma/exploration/

Berkeley Segmentation Dataset (BSD) 200 https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU

DIV2K800 is the dataset I have been using in previous posts.

 

Data preprocessing

function generate_mod_LR_bic()
%% matlab code to generate mod images, bicubic-downsampled LR, and bicubic-upsampled images.

%% set parameters
% comment the unnecessary line
input_folder = '/home/guanwp/BasicSR_datasets/DIV2K800_sub';
% save_mod_folder = '';
%save_LR_folder = '/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4';
save_bic_folder = '/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicubic_X4';

up_scale = 4;
mod_scale = 4;

if exist('save_mod_folder', 'var')
    if exist(save_mod_folder, 'dir')
        disp(['It will cover ', save_mod_folder]);
    else
        mkdir(save_mod_folder);
    end
end
if exist('save_LR_folder', 'var')
    if exist(save_LR_folder, 'dir')
        disp(['It will cover ', save_LR_folder]);
    else
        mkdir(save_LR_folder);
    end
end
if exist('save_bic_folder', 'var')
    if exist(save_bic_folder, 'dir')
        disp(['It will cover ', save_bic_folder]);
    else
        mkdir(save_bic_folder);
    end
end

idx = 0;
filepaths = dir(fullfile(input_folder,'*.*'));
for i = 1 : length(filepaths)
    [paths,imname,ext] = fileparts(filepaths(i).name);
    if isempty(imname)
        disp('Ignore . folder.');
    elseif strcmp(imname, '.')
        disp('Ignore .. folder.');
    else
        idx = idx + 1;
        str_rlt = sprintf('%d\t%s.\n', idx, imname);
        fprintf(str_rlt);
        % read image
        img = imread(fullfile(input_folder, [imname, ext]));
        img = im2double(img);
        % modcrop
        img = modcrop(img, mod_scale);
        if exist('save_mod_folder', 'var')
            imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
        end
        % LR
        im_LR = imresize(img, 1/up_scale, 'bicubic');
        if exist('save_LR_folder', 'var')
            imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png']));
        end
        % Bicubic
        if exist('save_bic_folder', 'var')
            im_B = imresize(im_LR, up_scale, 'bicubic');
            imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png']));
        end
    end
end
end

%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
    sz = size(img);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2));
else
    tmpsz = size(img);
    sz = tmpsz(1:2);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2),:);
end
end
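For reference, here is a rough Python equivalent of the modcrop + bicubic pipeline (a sketch using Pillow; the folder paths are placeholders, and PIL's bicubic kernel differs slightly from MATLAB's imresize, so results will not match bit-for-bit):

import os
from PIL import Image

input_folder = '/path/to/HR_sub'          # placeholder paths
save_bic_folder = '/path/to/HR_sub_bicubic_X4'
up_scale = 4

os.makedirs(save_bic_folder, exist_ok=True)
for name in sorted(os.listdir(input_folder)):
    stem, ext = os.path.splitext(name)
    if ext.lower() not in ('.png', '.jpg', '.bmp'):
        continue
    img = Image.open(os.path.join(input_folder, name))
    # modcrop: crop so height and width are multiples of the scale
    w, h = img.size
    w, h = w - w % up_scale, h - h % up_scale
    img = img.crop((0, 0, w, h))
    # bicubic down- then up-sampling, as in the MATLAB script
    lr = img.resize((w // up_scale, h // up_scale), Image.BICUBIC)
    bic = lr.resize((w, h), Image.BICUBIC)
    bic.save(os.path.join(save_bic_folder, stem + '_bicx4.png'))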

 

Settings

Sub-images of 512×512, stride 512/2. Two training configurations follow: the first trains on MWCNN_data_sub, the second on DIV2K800_sub.

{
  "name": "MWCNN_DATA" //"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard wounld not work
  ,
  "use_tb_logger": true,
  "model": "sr",
  //"crop_scale": 0,
   "scale": 1//it must be 1
  ,
  "gpu_ids": [4,5],
  "datasets": {
    "train": {
      "name": "MWCNN_DATA",
      "mode": "LRHR" //it must be this, and the detail would be shown in LRHR_dataset.py
      //, "noise_get": true///
      ,
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/MWCNN_data_sub" ///must be sub
      ,
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/MWCNN_data_sub_bicubic_X4",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 24//16//32 //how many samples in each iters
      ,
      "HR_size": 128// 128 | 192
      ,
      "use_flip": false //true//
      ,
      "use_rot": false //true
    },
    "val": {
      "name": "Set5",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_bicubic_X4"
      //, "noise_get": true///this is important
    }
  },
  "path": {
    "root": "/home/guanwp/BasicSR-master/",
    "pretrain_model_G": null,
    "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/MWCNN_DATA/models",
    "log": "/home/guanwp/BasicSR-master/experiments/MWCNN_DATA",
    "val_images": "/home/guanwp/BasicSR-master/experiments/MWCNN_DATA/val_images"
  },
  "network_G": {
    "which_model_G":"mwcnn"//"noise_estimation" //"espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    ,
    "norm_type": null,
    "mode": "CNA",
    "nf": 64 //56//64
    ,
    "nb": 16,//number of residual block
    "in_nc": 3,
    "out_nc": 3,
    "gc": 32,
    "group": 1
  },
  "train": {
    "lr_G": 1e-3//8e-4 //1e-3//2e-4
    ,
    "lr_scheme": "MultiStepLR",
    "lr_steps": [300000,400000,600000,800000,1000000],
    "lr_gamma": 0.5,
    "pixel_criterion": "l2" //"l2_tv"//"l1"//'l2'//huber//Cross   //should be MSE LOSS
    ,
    "pixel_weight": 1.0,
    "val_freq": 1e3,
    "manual_seed": 0,
    "niter": 1.2e6 //2e6//1e6
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 1e3
  }
}
The second configuration, for DIV2K800_sub:

{
  "name": "MWCNN_DIVIK" //"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard wounld not work
  ,
  "use_tb_logger": true,
  "model": "sr",
  //"crop_scale": 0,
   "scale": 1//it must be 1
  ,
  "gpu_ids": [4,5],
  "datasets": {
    "train": {
      "name": "DIV2K80",
      "mode": "LRHR" //it must be this, and the detail would be shown in LRHR_dataset.py
      //, "noise_get": true///
      ,
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub" ///must be sub
      ,
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicubic_X4",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 16//32 //how many samples in each iters
      ,
      "HR_size": 128// 128 | 192
      ,
      "use_flip": false //true//
      ,
      "use_rot": false //true
    },
    "val": {
      "name": "Set5",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_bicubic_X4"
      //, "noise_get": true///this is important
    }
  },
  "path": {
    "root": "/home/guanwp/BasicSR-master/",
    "pretrain_model_G": null,
    "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/MWCNN_DIVIK/models",
    "log": "/home/guanwp/BasicSR-master/experiments/MWCNN_DIVIK",
    "val_images": "/home/guanwp/BasicSR-master/experiments/MWCNN_DIVIK/val_images"
  },
  "network_G": {
    "which_model_G":"mwcnn"//"noise_estimation" //"espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    ,
    "norm_type": null,
    "mode": "CNA",
    "nf": 64 //56//64
    ,
    "nb": 16,//number of residual block
    "in_nc": 3,
    "out_nc": 3,
    "gc": 32,
    "group": 1
  },
  "train": {
    "lr_G": 1e-3//8e-4 //1e-3//2e-4
    ,
    "lr_scheme": "MultiStepLR",
    "lr_steps": [300000,400000,600000,800000,1000000],
    "lr_gamma": 0.5,
    "pixel_criterion": "l2" //"l2_tv"//"l1"//'l2'//huber//Cross   //should be MSE LOSS
    ,
    "pixel_weight": 1.0,
    "val_freq": 1e3,
    "manual_seed": 0,
    "niter": 1.2e6 //2e6//1e6
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 1e3
  }
}
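The "MultiStepLR" scheme in both configs halves the learning rate at each milestone. A minimal PyTorch illustration of that schedule (dummy parameter, only to show the decay):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[300000, 400000, 600000, 800000, 1000000], gamma=0.5)

# lr stays at 1e-3 until iter 300k, then 5e-4, 2.5e-4, ... down to ~3.1e-5 after 1M iters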

Network definition

Add the following branch to the network factory:

#############################################################################################################
    elif which_model=='mwcnn':#MWCNN
        netG=arch.MWCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
#############################################################################################################
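With this branch in place, setting "which_model_G": "mwcnn" in the JSON configs above selects this network.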

The network itself:

#######################################################################################################
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward, DWTInverse

class Block_of_DMT1(nn.Module):
    def __init__(self):
        super(Block_of_DMT1,self).__init__()

        #DMT1
        self.conv1_1=nn.Conv2d(in_channels=160,out_channels=160,kernel_size=3,stride=1,padding=1)
        self.bn1_1=nn.BatchNorm2d(160, affine=True)
        self.relu1_1=nn.ReLU()

    def forward(self, x):
        output = self.relu1_1(self.bn1_1(self.conv1_1(x)))
        return output 

class Block_of_DMT2(nn.Module):
    def __init__(self):
        super(Block_of_DMT2,self).__init__()

        #DMT2
        self.conv2_1=nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1)
        self.bn2_1=nn.BatchNorm2d(256, affine=True)
        self.relu2_1=nn.ReLU()

    def forward(self, x):
        output = self.relu2_1(self.bn2_1(self.conv2_1(x)))
        return output 

class Block_of_DMT3(nn.Module):
    def __init__(self):
        super(Block_of_DMT3,self).__init__()

        #DMT3
        self.conv3_1=nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1)
        self.bn3_1=nn.BatchNorm2d(256, affine=True)
        self.relu3_1=nn.ReLU()

    def forward(self, x):
        output = self.relu3_1(self.bn3_1(self.conv3_1(x)))
        return output 


# MWCNN
class MWCNN(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=2, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        # Note: upscale/upsample_mode are kept for interface compatibility but unused
        # (MWCNN outputs the same size as its input), and the channel widths below
        # are hard-coded rather than derived from nf.
        super(MWCNN, self).__init__()

        # Single-level Haar DWT/IWT from pytorch_wavelets, reused at every level
        self.DWT = DWTForward(J=1, wave='haar').cuda()
        self.IDWT = DWTInverse(wave='haar').cuda()

        #DMT1 operation
        #DMT1: after DWT the 3-channel input becomes 3*4 = 12 channels
        self.conv_DMT1=nn.Conv2d(in_channels=3*4,out_channels=160,kernel_size=3,stride=1,padding=1)
        self.bn_DMT1=nn.BatchNorm2d(160, affine=True)
        self.relu_DMT1=nn.ReLU()
        #IDMT1: last layer, no BN/ReLU, predicts the residual (3*4 channels before IWT)
        self.conv_IDMT1=nn.Conv2d(in_channels=160,out_channels=3*4,kernel_size=3,stride=1,padding=1)


        self.blockDMT1=self.make_layer(Block_of_DMT1,3)

        #DMT2 operation
        #DMT2
        self.conv_DMT2=nn.Conv2d(in_channels=640,out_channels=256,kernel_size=3,stride=1,padding=1)  # 640 = 160*4 after DWT
        self.bn_DMT2=nn.BatchNorm2d(256, affine=True)
        self.relu_DMT2=nn.ReLU()
        #IDMT2
        self.conv_IDMT2=nn.Conv2d(in_channels=256,out_channels=640,kernel_size=3,stride=1,padding=1)
        self.bn_IDMT2=nn.BatchNorm2d(640, affine=True)
        self.relu_IDMT2=nn.ReLU()

        self.blockDMT2=self.make_layer(Block_of_DMT2,3)

        #DMT3 operation
        #DMT3
        self.conv_DMT3=nn.Conv2d(in_channels=1024,out_channels=256,kernel_size=3,stride=1,padding=1)  # 1024 = 256*4 after DWT
        self.bn_DMT3=nn.BatchNorm2d(256, affine=True)
        self.relu_DMT3=nn.ReLU()
        #IDMT3
        self.conv_IDMT3=nn.Conv2d(in_channels=256,out_channels=1024,kernel_size=3,stride=1,padding=1)
        self.bn_IDMT3=nn.BatchNorm2d(1024, affine=True)
        self.relu_IDMT3=nn.ReLU()

        self.blockDMT3=self.make_layer(Block_of_DMT3,3)



    def make_layer(self, block, num_of_layer):
        layers = []
        for _ in range(num_of_layer):
            layers.append(block())
        return nn.Sequential(*layers)

    def _transformer(self, DMT1_yl, DMT1_yh):
        # Pack the DWT output into one tensor along the channel axis:
        # the three high-frequency subbands first, then the low-frequency band.
        list_tensor = []
        for i in range(3):
            list_tensor.append(DMT1_yh[0][:, :, i, :, :])
        list_tensor.append(DMT1_yl)
        return torch.cat(list_tensor, 1)


    def _Itransformer(self, out):
        # Unpack the channel-packed tensor back into (yl, yh) for DWTInverse,
        # mirroring exactly the ordering used in _transformer above.
        C = out.shape[1] // 4  # integer division: channels per subband
        y = out.reshape((out.shape[0], 4, C, out.shape[-2], out.shape[-1]))
        yl = y[:, 3].contiguous()  # the low-frequency band was appended last
        yh = [y[:, 0:3].permute(0, 2, 1, 3, 4).contiguous()]  # (N, C, 3, H, W)
        return yl, yh

    def forward(self, x):
        DMT1_p = x
        # DMT1: DWT -> channel compression -> conv block (160 channels)
        DMT1_yl, DMT1_yh = self.DWT(x)
        DMT1 = self._transformer(DMT1_yl, DMT1_yh)
        out = self.relu_DMT1(self.bn_DMT1(self.conv_DMT1(DMT1)))
        out = self.blockDMT1(out)

        DMT2_p = out
        # DMT2 (256 channels)
        DMT2_yl, DMT2_yh = self.DWT(out)
        DMT2 = self._transformer(DMT2_yl, DMT2_yh)
        out = self.relu_DMT2(self.bn_DMT2(self.conv_DMT2(DMT2)))
        out = self.blockDMT2(out)

        DMT3_p = out
        # DMT3 (256 channels)
        DMT3_yl, DMT3_yh = self.DWT(out)
        DMT3 = self._transformer(DMT3_yl, DMT3_yh)
        out = self.relu_DMT3(self.bn_DMT3(self.conv_DMT3(DMT3)))
        out = self.blockDMT3(out)

        # IDMT3. Note that blockDMT3 is reused here (acting as the innermost "DMT4"
        # block), so the encoder and decoder share weights at this level.
        out = self.blockDMT3(out)
        out = self.relu_IDMT3(self.bn_IDMT3(self.conv_IDMT3(out)))
        out = self._Itransformer(out)
        IDMT3 = self.IDWT(out)
        out = IDMT3 + DMT3_p

        # IDMT2
        out = self.blockDMT2(out)
        out = self.relu_IDMT2(self.bn_IDMT2(self.conv_IDMT2(out)))
        out = self._Itransformer(out)
        IDMT2 = self.IDWT(out)
        out = IDMT2 + DMT2_p

        # IDMT1: the final conv has no BN/ReLU and predicts the residual image
        out = self.blockDMT1(out)
        out = self.conv_IDMT1(out)
        out = self._Itransformer(out)
        IDMT1 = self.IDWT(out)
        out = IDMT1 + DMT1_p

        return out
 
##########################################
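A quick shape check for the reimplementation (run in the same module, so torch and MWCNN are in scope; a CUDA device is assumed since the DWT modules are created with .cuda(), and height/width should be multiples of 8 for three DWT levels):

net = MWCNN(in_nc=3, out_nc=3, nf=64, nb=16).cuda()
net.eval()  # avoid BatchNorm statistics updates for a single test sample
x = torch.randn(1, 3, 128, 128).cuda()
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 128, 128]) -- output matches the input size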

 

Training results

 

 

 

Test results

 

 

 

 

 
