JosePhilo


[Deep Learning] Blind Image Denoising with DnCNN and a Comparison of Optimized Variants

Category: Miscellaneous
Tag: none
Written by Joseph with ♥ on June 22, 2021

Contents
  • I. DnCNNs with ReLU/Leaky_ReLU Activation
    • Raw DnCNNs with ReLU Activation
    • "Improved?" DnCNNs with Leaky_ReLU Activation
  • II. 1. Download the Dataset
  • III. 2. Add Noise to the Training Set
  • IV. 3. DnCNN
    • Why CNNs Need Weight Initialization
    • PSNR
    • Training DnCNN
    • Results
    • DnCNN Network Parameters
  • V. 4. UDnCNN
    • Training UDnCNN
    • UDnCNN Network Parameters
    • Comparing DnCNN and UDnCNN
  • VI. 5. DUDnCNN
    • Training DUDnCNN
    • Comparing DnCNN, UDnCNN, and DUDnCNN
    • DUDnCNN Network Parameters
  • VII. 1. Download the Dataset
  • VIII. 2. Add Noise to the Training Set
  • IX. 3. DnCNN
    • Why CNNs Need Weight Initialization
    • PSNR
    • Training DnCNN
    • Results
    • DnCNN Network Parameters
  • X. 4. UDnCNN
    • Training UDnCNN
    • UDnCNN Network Parameters
    • Comparing DnCNN and UDnCNN
  • XI. 5. DUDnCNN
    • Training DUDnCNN
    • Comparing DnCNN, UDnCNN, and DUDnCNN
    • DUDnCNN Network Parameters

ReLU version: Open In Colab
Leaky_ReLU version: Open In Colab

I. DnCNNs with ReLU/Leaky_ReLU Activation

Models trained on CUDA with batch size 4, for 200 epochs each.

No statistically significant improvement from using Leaky_ReLU in place of ReLU, period.

Raw DnCNNs with ReLU Activation

Eval / CNNs    DnCNN          UDnCNN         DUDnCNN
PSNR           29.0762        28.4196        29.3118
Loss           0.005108381    0.005901728    0.004859959

"Improved?" DnCNNs with Leaky_ReLU Activation

Eval / CNNs    DnCNN            UDnCNN           DUDnCNN
PSNR           29.0926 ↑        28.3110 ↓        29.1659 ↓
Loss           0.005087597 ↓    0.006038793 ↑    0.005011519 ↑
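
Computing the per-model PSNR change from ReLU to Leaky_ReLU (a quick sketch using the numbers from the two tables above):

relu  = {'DnCNN': 29.0762, 'UDnCNN': 28.4196, 'DUDnCNN': 29.3118}
leaky = {'DnCNN': 29.0926, 'UDnCNN': 28.3110, 'DUDnCNN': 29.1659}
for name in relu:
    print(f'{name}: {leaky[name] - relu[name]:+.4f} dB')
# DnCNN: +0.0164 dB, UDnCNN: -0.1086 dB, DUDnCNN: -0.1459 dB

All three deltas are well under 0.15 dB, hence the verdict above.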



DnCNN UDnCNN DUDnCNN with ReLU

II. 1. Download the Dataset

!wget -N https://raw.githubusercontent.com/eebowen/Transfer-Learning-and-Deep-Neural-Network-Acceleration-for-Image-Classification/master/nntools.py
!wget -N https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz
!tar -zxvf BSDS300-images.tgz
dataset_root_dir = './BSDS300/images/'
--2021-06-17 15:11:46--  https://raw.githubusercontent.com/eebowen/Transfer-Learning-and-Deep-Neural-Network-Acceleration-for-Image-Classification/master/nntools.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12781 (12K) [text/plain]
Saving to: ‘nntools.py’

nntools.py          100%[===================>]  12.48K  --.-KB/s    in 0s      

Last-modified header missing -- time-stamps turned off.
2021-06-17 15:11:46 (51.1 MB/s) - ‘nntools.py’ saved [12781/12781]

--2021-06-17 15:11:46--  https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz
Resolving www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)... 128.32.244.190
Connecting to www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 22211847 (21M) [application/x-tar]
Saving to: ‘BSDS300-images.tgz’

BSDS300-images.tgz  100%[===================>]  21.18M  3.31MB/s    in 5.6s    

2021-06-17 15:11:53 (3.76 MB/s) - ‘BSDS300-images.tgz’ saved [22211847/22211847]



%matplotlib inline

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
import nntools as nt
import time
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
cuda:0

III. 2. Add Noise to the Training Set

Gaussian noise with $σ = 30$ (on the [0, 255] scale), applied to a $180 × 180$ crop (top-left corner, or optionally a random position).

class NoisyBSDSDataset(td.Dataset):

    def __init__(self, root_dir, mode='train', image_size=(180, 180), sigma=30):
        super(NoisyBSDSDataset, self).__init__()
        self.mode = mode
        self.image_size = image_size
        self.sigma = sigma
        self.images_dir = os.path.join(root_dir, mode)
        self.files = os.listdir(self.images_dir)

    def __len__(self):
        return len(self.files)

    def __repr__(self):
        return "NoisyBSDSDataset(mode={}, image_size={}, sigma={})". \
            format(self.mode, self.image_size, self.sigma)

    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.files[idx])
        clean = Image.open(img_path).convert('RGB')   
        # random crop (disabled below in favor of a fixed top-left crop)
        #i = np.random.randint(clean.size[0] - self.image_size[0])
        #j = np.random.randint(clean.size[1] - self.image_size[1])
        i=0
        j=0
        clean = clean.crop([i, j, i+self.image_size[0], j+self.image_size[1]])
        transform = tv.transforms.Compose([
            # convert to a tensor in [0, 1]
            tv.transforms.ToTensor(),
            # rescale to [-1, 1]
            tv.transforms.Normalize((.5, .5, .5), (.5, .5, .5))
            ])
        clean = transform(clean)
        
        # sigma is given on the [0, 255] scale while the images live in
        # [-1, 1] (peak-to-peak range 2), hence the factor 2/255
        noisy = clean + 2 / 255 * self.sigma * torch.randn(clean.shape)
        return noisy, clean
def myimshow(image, ax=plt):
    image = image.to('cpu').numpy()
    image = np.moveaxis(image, [0, 1, 2], [2, 0, 1])
    image = (image + 1) / 2
    image[image < 0] = 0
    image[image > 1] = 1
    h = ax.imshow(image)
    ax.axis('off')
    return h

Load the training and test sets.

train_set = NoisyBSDSDataset(dataset_root_dir)
test_set = NoisyBSDSDataset(dataset_root_dir, mode='test', image_size=(320, 320))
x = test_set[0]
fig, axes = plt.subplots(ncols=2)
myimshow(x[0], ax=axes[0])
axes[0].set_title('Noisy')
myimshow(x[1], ax=axes[1])
axes[1].set_title('Clean')
print(f'image size is {x[0].shape}.')
image size is torch.Size([3, 320, 320]).



[Figure: noisy and clean test image side by side]

IV. 3. DnCNN

The loss is mean squared error (MSE):

class NNRegressor(nt.NeuralNetwork):

    def __init__(self):
        super(NNRegressor, self).__init__()
        self.mse = nn.MSELoss()

    def criterion(self, y, d):
        return self.mse(y, d)

Why CNNs Need Weight Initialization

This write-up (in Traditional Chinese) explains it clearly:
深度學習: Weight initialization和Batch Normalization

Without weight initialization:

class DnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(DnCNN, self).__init__()
        self.D = D

        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=1))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=1) for _ in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=1))

        self.bn = nn.ModuleList()
        for k in range(D):
            # NB: the second positional argument of BatchNorm2d is eps,
            # so this sets eps=64 (visible as eps=64 in the module repr
            # later); nn.BatchNorm2d(C) was probably intended
            self.bn.append(nn.BatchNorm2d(C, C))

    def forward(self, x):
        D = self.D
        h = F.relu(self.conv[0](x))
        for i in range(D):
            h = F.relu(self.bn[i](self.conv[i+1](h)))
        y = self.conv[D+1](h) + x
        return y

Zero padding ('same' convolution) means the filters also act on the borders of the input image; what happens there is a bit of black magic.

x, _ = train_set[-1]
x = x.unsqueeze(0).to(device)
Ds = [0, 1, 2, 4, 8]

fig, axes = plt.subplots(nrows=len(Ds), ncols=3, figsize=(9,9))
for i in range(len(Ds)):
    with torch.no_grad():
        model = DnCNN(Ds[i]).to(device)
        y = model.forward(x) # 4-d
    # 3-d
    myimshow(x[0], ax=axes[i][0])
    axes[i][0].set_title('x[0]')
    myimshow(y[0], ax=axes[i][1])
    axes[i][1].set_title(f'y[0] (D={Ds[i]})')
    myimshow(x[0]-y[0], ax=axes[i][2])
    axes[i][2].set_title(f'x[0]-y[0] (D={Ds[i]})')

[Figure: x[0], y[0], and residual x[0]-y[0] for D = 0, 1, 2, 4, 8]

Only D = 0 produces a non-trivial residual; for deeper uninitialized nets the gradients vanish and training goes nowhere.

Now with weight initialization:

class DnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(DnCNN, self).__init__()
        self.D = D

        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=1))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=1) for _ in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=1))
        # Kaiming normal initialization, a.k.a. He('s) initialization
        for i in range(len(self.conv[:-1])):
            nn.init.kaiming_normal_(self.conv[i].weight.data, nonlinearity='relu')

        # Batch norm
        self.bn = nn.ModuleList()
        self.bn.extend([nn.BatchNorm2d(C, C) for _ in range(D)])
        # initialize the batch norm scale (weight)
        for i in range(D):
            nn.init.constant_(self.bn[i].weight.data, 1.25 * np.sqrt(C))

    def forward(self, x):
        D = self.D
        h = F.relu(self.conv[0](x))
        for i in range(D):
            h = F.relu(self.bn[i](self.conv[i+1](h)))
        y = self.conv[D+1](h) + x
        return y
x, _ = train_set[-1]
x = x.unsqueeze(0).to(device)
Ds = [0, 1, 2, 4, 8]

fig, axes = plt.subplots(nrows=len(Ds), ncols=3, figsize=(9,9))
for i in range(len(Ds)):
    with torch.no_grad():
        model = DnCNN(Ds[i]).to(device)
        y = model.forward(x)

    myimshow(x[0], ax=axes[i][0])
    axes[i][0].set_title('x[0]')
    myimshow(y[0], ax=axes[i][1])
    axes[i][1].set_title(f'y[0] (D={Ds[i]})')
    myimshow(x[0]-y[0], ax=axes[i][2])
    axes[i][2].set_title(f'x[0]-y[0] (D={Ds[i]})')

[Figure: x[0], y[0], and residual x[0]-y[0] for D = 0, 1, 2, 4, 8]

Now the residuals are non-zero, gradients can flow, and the networks can actually be trained.

PSNR

Peak Signal-to-Noise Ratio (PSNR). The images here take values in $[-1, 1]$ (peak-to-peak range 2), which is where the 4 in the numerator comes from:

$$ PSNR = 10\log_{10}\frac{4n}{\Vert y-d\Vert_2^2} $$

$d$ is the clean target and $y$ the estimate; dividing the denominator by $n$ (the number of tensor elements) gives the mean squared error. Being log-scaled, PSNR is measured in dB, and higher is better.

It has to be averaged over training updates:
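
For reference, a minimal standalone version of the statistic accumulated below (the psnr name and the peak argument are mine, not the notebook's):

import torch

def psnr(y, d, peak=2.0):
    # images in [-1, 1] have peak-to-peak range 2, so peak**2 = 4
    n = y.numel()
    return 10 * torch.log10(peak**2 * n / torch.norm(y - d)**2)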

class DenoisingStatsManager(nt.StatsManager):

    def __init__(self):
        super(DenoisingStatsManager, self).__init__()

    def init(self):
        super(DenoisingStatsManager, self).init()
        self.running_psnr = 0

    def accumulate(self, loss, x, y, d):
        super(DenoisingStatsManager, self).accumulate(loss, x, y, d)
        n = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3]
        self.running_psnr += 10*torch.log10(4*n/(torch.norm(y-d)**2))

    def summarize(self):
        loss = super(DenoisingStatsManager, self).summarize()
        psnr = self.running_psnr / self.number_update
        return {'loss': loss, 'PSNR': psnr.cpu()}
def plot(exp, fig, axes, noisy, visu_rate=2):
    if exp.epoch % visu_rate != 0:
        return
    with torch.no_grad():
        denoised = exp.net(noisy[None].to(net.device))[0]
    axes[0][0].clear()
    axes[0][1].clear()
    axes[1][0].clear()
    axes[1][1].clear()
    myimshow(noisy, ax=axes[0][0])
    axes[0][0].set_title('Noisy image')

    myimshow(denoised, ax=axes[0][1])
    axes[0][1].set_title('Denoised image')

    axes[1][0].plot([exp.history[k][0]['loss'] for k in range(exp.epoch)], label='training loss')
    axes[1][0].set_ylabel('Loss')
    axes[1][0].set_xlabel('Epoch')
    axes[1][0].legend()

    axes[1][1].plot([exp.history[k][0]['PSNR'] for k in range(exp.epoch)], label='training psnr')
    axes[1][1].set_ylabel('PSNR')
    axes[1][1].set_xlabel('Epoch')
    axes[1][1].legend()

    plt.tight_layout()
    fig.canvas.draw()

Training DnCNN

lr = 1e-3
net = DnCNN(6).to(device)
adam = torch.optim.Adam(net.parameters(), lr=lr)
stats_manager = DenoisingStatsManager()
exp1 = nt.Experiment(net, train_set, test_set, adam, stats_manager, batch_size=4,
               output_dir="./checkpoints/denoising1", perform_validation_during_training=True)
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
exp1.run(num_epochs=200, plot=lambda exp: plot(exp, fig=fig, axes=axes,
                                                noisy=test_set[0][0]))
Start/Continue training from epoch 0
Epoch 1 (Time: 18.20s)
Epoch 2 (Time: 18.09s)
...
Epoch 199 (Time: 17.82s)
Epoch 200 (Time: 17.84s)
Finish training for 200 epochs



[Figure: noisy/denoised images and training loss/PSNR curves]

Results

img = []
model = exp1.net.to(device)
titles = ['clean', 'noise', 'DnCNN']

x, clean = test_set[0]
x = x.unsqueeze(0).to(device)
img.append(clean)
img.append(x[0])

model.eval()
with torch.no_grad():
    y = model.forward(x)
img.append(y[0])

fig, axes = plt.subplots(ncols=3, figsize=(20,10), sharex='all', sharey='all')
for i in range(len(img)):
    myimshow(img[i], ax=axes[i])
    axes[i].set_title(f'{titles[i]}')

[Figure: clean, noisy, and DnCNN-denoised images]

Some noise remains, and some detail has been lost.

DnCNN Network Parameters

for name, param in model.named_parameters():
    print(name, param.size(), param.requires_grad)
conv.0.weight torch.Size([64, 3, 3, 3]) True
conv.0.bias torch.Size([64]) True
conv.1.weight torch.Size([64, 64, 3, 3]) True
conv.1.bias torch.Size([64]) True
conv.2.weight torch.Size([64, 64, 3, 3]) True
conv.2.bias torch.Size([64]) True
conv.3.weight torch.Size([64, 64, 3, 3]) True
conv.3.bias torch.Size([64]) True
conv.4.weight torch.Size([64, 64, 3, 3]) True
conv.4.bias torch.Size([64]) True
conv.5.weight torch.Size([64, 64, 3, 3]) True
conv.5.bias torch.Size([64]) True
conv.6.weight torch.Size([64, 64, 3, 3]) True
conv.6.bias torch.Size([64]) True
conv.7.weight torch.Size([3, 64, 3, 3]) True
conv.7.bias torch.Size([3]) True
bn.0.weight torch.Size([64]) True
bn.0.bias torch.Size([64]) True
bn.1.weight torch.Size([64]) True
bn.1.bias torch.Size([64]) True
bn.2.weight torch.Size([64]) True
bn.2.bias torch.Size([64]) True
bn.3.weight torch.Size([64]) True
bn.3.bias torch.Size([64]) True
bn.4.weight torch.Size([64]) True
bn.4.bias torch.Size([64]) True
bn.5.weight torch.Size([64]) True
bn.5.bias torch.Size([64]) True

Parameter count

The first layer has 64 x 3 x 3 x 3 weights, each of the D middle layers has 64 x 64 x 3 x 3, and the last layer has 3 x 64 x 3 x 3. In total: 3456 + 36864 x D.
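
A quick check of the formula against the trained model above (counting conv weights only; biases and batch-norm parameters are excluded):

conv_weights = sum(p.numel() for name, p in model.named_parameters()
                   if name.startswith('conv') and name.endswith('weight'))
print(conv_weights)   # 3456 + 36864 * 6 = 224640 for D = 6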

Receptive field calculation:

There is no pooling, so each layer adds a fixed $2^{0-0+1}=2$ to the side length; starting from 1 at the input, grade-school arithmetic gives $R_D=(1+2\times (D+2))^2$.

For D = 6, $R_6=17^2$.

Reportedly (citation needed), denoising σ = 30 Gaussian noise requires each output pixel to be influenced by 33 × 33 input pixels, and the depth is chosen accordingly.

Setting the receptive-field side $1+2\times (D+2)$ equal to 33 gives $D=14$, for $3456 + 36864 \times 14 = 519552$ parameters.
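
The same side length can also be computed iteratively; a tiny sketch (the function name is mine):

def dncnn_rf_side(D):
    # each 3x3 conv (stride 1, no pooling) adds kernel_size - 1 = 2
    side = 1
    for _ in range(D + 2):   # first conv + D middle convs + last conv
        side += 2
    return side

assert dncnn_rf_side(6) == 17
assert dncnn_rf_side(14) == 33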

V. 4. UDnCNN

U-net like CNNs

class UDnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(UDnCNN, self).__init__()
        self.D = D

        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=1))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=1) for _ in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=1))
        # Kaiming normal initialization, a.k.a. He('s) initialization
        for i in range(len(self.conv[:-1])):
            nn.init.kaiming_normal_(self.conv[i].weight.data, nonlinearity='relu')

        # batch norm
        self.bn = nn.ModuleList()
        self.bn.extend([nn.BatchNorm2d(C, C) for _ in range(D)])
        # initialize the batch norm scale (weight)
        for i in range(D):
            nn.init.constant_(self.bn[i].weight.data, 1.25 * np.sqrt(C))

    # everything above is unchanged; the forward pass becomes U-Net-like
    def forward(self, x):
        D = self.D
        h = F.relu(self.conv[0](x))
        h_buff = []
        idx_buff = []
        shape_buff = []
        for i in range(D//2-1):
            shape_buff.append(h.shape)
            h, idx = F.max_pool2d(F.relu(self.bn[i](self.conv[i+1](h))),
                                  kernel_size=(2,2), return_indices=True)
            h_buff.append(h)
            idx_buff.append(idx)
        for i in range(D//2-1, D//2+1):
            h = F.relu(self.bn[i](self.conv[i+1](h)))
        for i in range(D//2+1, D):
            j = i - (D//2 + 1) + 1
            h = F.max_unpool2d(F.relu(self.bn[i](self.conv[i+1]((h+h_buff[-j])/np.sqrt(2)))),
                               idx_buff[-j], kernel_size=(2,2), output_size=shape_buff[-j])
        y = self.conv[D+1](h) + x
        return y

Training UDnCNN

lr = 1e-3
net = UDnCNN(6).to(device)
adam = torch.optim.Adam(net.parameters(), lr=lr)
stats_manager = DenoisingStatsManager()
exp2 = nt.Experiment(net, train_set, test_set, adam, stats_manager, batch_size=4,
               output_dir="./checkpoints/denoising2", perform_validation_during_training=True)
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
exp2.run(num_epochs=200, plot=lambda exp: plot(exp, fig=fig, axes=axes,
                                                noisy=test_set[0][0]))
Start/Continue training from epoch 200
Finish training for 200 epochs



[Figure: noisy/denoised images and training loss/PSNR curves]

UDnCNN Network Parameters

for name, param in exp2.net.named_parameters():
    print(name, param.size(), param.requires_grad)
conv.0.weight torch.Size([64, 3, 3, 3]) True
conv.0.bias torch.Size([64]) True
conv.1.weight torch.Size([64, 64, 3, 3]) True
conv.1.bias torch.Size([64]) True
conv.2.weight torch.Size([64, 64, 3, 3]) True
conv.2.bias torch.Size([64]) True
conv.3.weight torch.Size([64, 64, 3, 3]) True
conv.3.bias torch.Size([64]) True
conv.4.weight torch.Size([64, 64, 3, 3]) True
conv.4.bias torch.Size([64]) True
conv.5.weight torch.Size([64, 64, 3, 3]) True
conv.5.bias torch.Size([64]) True
conv.6.weight torch.Size([64, 64, 3, 3]) True
conv.6.bias torch.Size([64]) True
conv.7.weight torch.Size([3, 64, 3, 3]) True
conv.7.bias torch.Size([3]) True
bn.0.weight torch.Size([64]) True
bn.0.bias torch.Size([64]) True
bn.1.weight torch.Size([64]) True
bn.1.bias torch.Size([64]) True
bn.2.weight torch.Size([64]) True
bn.2.bias torch.Size([64]) True
bn.3.weight torch.Size([64]) True
bn.3.bias torch.Size([64]) True
bn.4.weight torch.Size([64]) True
bn.4.bias torch.Size([64]) True
bn.5.weight torch.Size([64]) True
bn.5.bias torch.Size([64]) True

Pooling does not change the parameter count, so it is still 3456 + 36864 x D.

Receptive field calculation:

$R_D=(1+\sum_{i=1}^{D/2}2^i+2\times 2^{D/2}+\sum_{i=1}^{D/2-1}2^i+2)^2$.

For D = 6: $R_6=(1+(2+4+8)+(2\times 8)+(4+2)+2)^2=39^2$.
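
A plain-Python transcription of the closed form (my function name; assumes D is even):

def udncnn_rf_side(D):
    side = 1
    side += sum(2**i for i in range(1, D//2 + 1))   # convs on the pooling path
    side += 2 * 2**(D//2)                           # the two bottleneck convs
    side += sum(2**i for i in range(1, D//2))       # convs on the unpooling path
    side += 2                                       # final conv
    return side

assert udncnn_rf_side(6) == 39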

Judging by PSNR, adding the U-Net structure actually makes things worse... pooling discards valuable feature information, even if the outputs arguably look "nicer"...

Comparing DnCNN and UDnCNN

# DnCNN
exp1.evaluate()
{'PSNR': tensor(29.0716), 'loss': 0.0051106692105531695}



# UDnCNN
exp2.evaluate()
{'PSNR': tensor(28.4064), 'loss': 0.0059140139166265725}



img = []
titles = ['clean', 'noise', 'DnCNN','UDnCNN']

x, clean = test_set[0]
x = x.unsqueeze(0).to(device)
img.append(clean)
img.append(x[0])

model = exp1.net.to(device)
model.eval()
with torch.no_grad():
    y = model.forward(x)
img.append(y[0])

model = exp2.net.to(device)
model.eval()
with torch.no_grad():
    y = model.forward(x)
img.append(y[0])

fig, axes = plt.subplots(ncols=4, figsize=(20,10), sharex='all', sharey='all')
for i in range(len(img)):
    myimshow(img[i], ax=axes[i])
    axes[i].set_title(f'{titles[i]}')

[Figure: clean, noisy, DnCNN, and UDnCNN outputs]

VI. 5. DUDnCNN

U-net like CNNs with dilated convolutions

Dilated convolutions replace pooling as the way to enlarge the receptive field.

However, dilated convolutions run surprisingly slowly in PyTorch, even though in principle they should be roughly as fast as ordinary convolutions; it is an optimization issue...

Setting torch.backends.cudnn.benchmark=True before each dilated convolution and switching back to torch.backends.cudnn.benchmark=False afterwards gives a speed-up; see https://github.com/pytorch/pytorch/issues/15054 for details.

class DUDnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(DUDnCNN, self).__init__()
        self.D = D

        # compute k (number of max-pools so far) and l (number of max-unpools so far)
        k = [0]
        k.extend([i for i in range(D//2)])
        k.extend([k[-1] for _ in range(D//2, D+1)])
        l = [0 for _ in range(D//2+1)]
        l.extend([i for i in range(D+1-(D//2+1))])
        l.append(l[-1])

        # dilation schedule for the dilated convolutions
        holes = [2**(kl[0]-kl[1])-1 for kl in zip(k,l)]
        dilations = [i+1 for i in holes]

        # conv layers
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=dilations[0], dilation=dilations[0]))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=dilations[i+1], dilation=dilations[i+1]) for i in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=dilations[-1], dilation=dilations[-1]))
        # Kaiming normal initialization, a.k.a. He('s) initialization
        for i in range(len(self.conv[:-1])):
            nn.init.kaiming_normal_(self.conv[i].weight.data, nonlinearity='relu')

        # batch norm
        self.bn = nn.ModuleList()
        self.bn.extend([nn.BatchNorm2d(C, C) for _ in range(D)])
        # initialize the batch norm scale (weight)
        for i in range(D):
            nn.init.constant_(self.bn[i].weight.data, 1.25 * np.sqrt(C))

    def forward(self, x):
        D = self.D
        h = F.relu(self.conv[0](x))
        h_buff = []

        for i in range(D//2 - 1):
            torch.backends.cudnn.benchmark = True
            h = self.conv[i+1](h)
            torch.backends.cudnn.benchmark = False
            h = F.relu(self.bn[i](h))
            h_buff.append(h)

        for i in range(D//2 - 1, D//2 + 1):
            torch.backends.cudnn.benchmark = True
            h = self.conv[i+1](h)
            torch.backends.cudnn.benchmark = False
            h = F.relu(self.bn[i](h))

        for i in range(D//2 + 1, D):
            j = i - (D//2 + 1) + 1
            torch.backends.cudnn.benchmark = True
            h = self.conv[i+1]((h + h_buff[-j]) / np.sqrt(2))
            torch.backends.cudnn.benchmark = False
            h = F.relu(self.bn[i](h))

        y = self.conv[D+1](h) + x
        return y
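
For D = 6 this schedule works out to dilations [1, 1, 2, 4, 4, 2, 1, 1], the same list you can read off the module repr printed further down; extracted as a standalone check:

D = 6
k = [0]
k.extend([i for i in range(D//2)])
k.extend([k[-1] for _ in range(D//2, D+1)])
l = [0 for _ in range(D//2+1)]
l.extend([i for i in range(D+1-(D//2+1))])
l.append(l[-1])
holes = [2**(a - b) - 1 for a, b in zip(k, l)]
print([h + 1 for h in holes])   # [1, 1, 2, 4, 4, 2, 1, 1]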

Training DUDnCNN

lr = 1e-3
net = DUDnCNN(6).to(device)
adam = torch.optim.Adam(net.parameters(), lr=lr)
stats_manager = DenoisingStatsManager()
exp3 = nt.Experiment(net, train_set, test_set, adam, stats_manager, batch_size=4,
               output_dir="./checkpoints/denoising3", perform_validation_during_training=True)
exp3
Net(DUDnCNN(
  (mse): MSELoss()
  (conv): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (bn): ModuleList(
    (0): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (5): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
  )
))
TrainSet(NoisyBSDSDataset(mode=train, image_size=(180, 180), sigma=30))
ValSet(NoisyBSDSDataset(mode=test, image_size=(320, 320), sigma=30))
Optimizer(Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
))
StatsManager(DenoisingStatsManager)
BatchSize(4)
PerformValidationDuringTraining(True)



fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
exp3.run(num_epochs=200, plot=lambda exp: plot(exp, fig=fig, axes=axes,
                                                noisy=test_set[0][0]))
Start/Continue training from epoch 11
Epoch 12 (Time: 19.62s)
Epoch 13 (Time: 19.43s)
...
Epoch 199 (Time: 18.91s)
Epoch 200 (Time: 18.94s)
Finish training for 200 epochs



[Figure: noisy/denoised images and training loss/PSNR curves]

Comparing DnCNN, UDnCNN, and DUDnCNN

# DnCNN
exp1.evaluate()
{'PSNR': tensor(29.0708), 'loss': 0.005108698001131415}



# UDnCNN
exp2.evaluate()
{'PSNR': tensor(28.4299), 'loss': 0.005884343096986413}



# DUDnCNN
exp3.evaluate()
{'PSNR': tensor(29.3069), 'loss': 0.004860227378085256}



num = 3
img = []
nets = [exp1.net, exp2.net, exp3.net]
titles = ['noise','DnCNN', 'UDnCNN', 'DUDnCNN']

fig, axes = plt.subplots(nrows=num, ncols=4, figsize=(20,15), sharex='all', sharey='all')
for i in range(num):
    myimshow(test_set[7*i+7][0], ax=axes[i][0])
    x, _ = test_set[7*i+7]
    x = x.unsqueeze(0).to(device)
    img.append(x)

for i in range(num):
    for j in range(len(nets)):
        model = nets[j].to(device)
        model.eval()
        with torch.no_grad():
            y = model.forward(img[i])
        myimshow(y[0], ax=axes[i][j+1])

for i in range(num):
    for j in range(len(titles)):
        axes[i][j].set_title(f'{titles[j]}')

[Figure: noisy inputs and DnCNN/UDnCNN/DUDnCNN outputs for three test images]

DUDnCNN Network Parameters

exp3.net
DUDnCNN(
  (mse): MSELoss()
  (conv): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (bn): ModuleList(
    (0): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (5): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
  )
)



for name, param in exp3.net.named_parameters():
    print(name, param.size(), param.requires_grad)
conv.0.weight torch.Size([64, 3, 3, 3]) True
conv.0.bias torch.Size([64]) True
conv.1.weight torch.Size([64, 64, 3, 3]) True
conv.1.bias torch.Size([64]) True
conv.2.weight torch.Size([64, 64, 3, 3]) True
conv.2.bias torch.Size([64]) True
conv.3.weight torch.Size([64, 64, 3, 3]) True
conv.3.bias torch.Size([64]) True
conv.4.weight torch.Size([64, 64, 3, 3]) True
conv.4.bias torch.Size([64]) True
conv.5.weight torch.Size([64, 64, 3, 3]) True
conv.5.bias torch.Size([64]) True
conv.6.weight torch.Size([64, 64, 3, 3]) True
conv.6.bias torch.Size([64]) True
conv.7.weight torch.Size([3, 64, 3, 3]) True
conv.7.bias torch.Size([3]) True
bn.0.weight torch.Size([64]) True
bn.0.bias torch.Size([64]) True
bn.1.weight torch.Size([64]) True
bn.1.bias torch.Size([64]) True
bn.2.weight torch.Size([64]) True
bn.2.bias torch.Size([64]) True
bn.3.weight torch.Size([64]) True
bn.3.bias torch.Size([64]) True
bn.4.weight torch.Size([64]) True
bn.4.bias torch.Size([64]) True
bn.5.weight torch.Size([64]) True
bn.5.bias torch.Size([64]) True

The parameter count is again unchanged: 3456 + 36864 x D.

Receptive field calculation:

$R_D=(1+2+\sum_{i=1}^{D/2}2^i\times 2^{i-1}+2^{D/2}\times 2^{D/2-1}+\sum_{i=1}^{D/2-1}2^i\times 2^{i-1}+2)^2$.

$R_6=89^2$.
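
And a transcription of that closed form (my function name; assumes D is even):

def dudncnn_rf_side(D):
    side = 1 + 2                                              # input + first conv
    side += sum(2**i * 2**(i-1) for i in range(1, D//2 + 1))  # growing dilations
    side += 2**(D//2) * 2**(D//2 - 1)                         # bottleneck conv
    side += sum(2**i * 2**(i-1) for i in range(1, D//2))      # shrinking dilations
    side += 2                                                 # final conv
    return side

assert dudncnn_rf_side(6) == 89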


DnCNN UDnCNN DUDnCNN with LeakyReLU

VII. 1. Download the Dataset

!wget -N https://raw.githubusercontent.com/eebowen/Transfer-Learning-and-Deep-Neural-Network-Acceleration-for-Image-Classification/master/nntools.py
!wget -N https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz
!tar -zxvf BSDS300-images.tgz
dataset_root_dir = './BSDS300/images/'
--2021-06-17 12:44:18--  https://raw.githubusercontent.com/eebowen/Transfer-Learning-and-Deep-Neural-Network-Acceleration-for-Image-Classification/master/nntools.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12781 (12K) [text/plain]
Saving to: ‘nntools.py’

nntools.py          100%[===================>]  12.48K  --.-KB/s    in 0s      

Last-modified header missing -- time-stamps turned off.
2021-06-17 12:44:18 (92.6 MB/s) - ‘nntools.py’ saved [12781/12781]

--2021-06-17 12:44:18--  https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz
Resolving www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)... 128.32.244.190
Connecting to www2.eecs.berkeley.edu (www2.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 22211847 (21M) [application/x-tar]
Saving to: ‘BSDS300-images.tgz’

BSDS300-images.tgz  100%[===================>]  21.18M  6.64MB/s    in 3.2s    

2021-06-17 12:44:21 (6.64 MB/s) - ‘BSDS300-images.tgz’ saved [22211847/22211847]


%matplotlib inline

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
import nntools as nt
import time
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
cuda:0

VIII. 2. Add Noise to the Training Set

Gaussian noise with $σ = 30$ (on the [0, 255] scale), applied to a $180 × 180$ crop (top-left corner, or optionally a random position).

class NoisyBSDSDataset(td.Dataset):

    def __init__(self, root_dir, mode='train', image_size=(180, 180), sigma=30):
        super(NoisyBSDSDataset, self).__init__()
        self.mode = mode
        self.image_size = image_size
        self.sigma = sigma
        self.images_dir = os.path.join(root_dir, mode)
        self.files = os.listdir(self.images_dir)

    def __len__(self):
        return len(self.files)

    def __repr__(self):
        return "NoisyBSDSDataset(mode={}, image_size={}, sigma={})". \
            format(self.mode, self.image_size, self.sigma)

    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.files[idx])
        clean = Image.open(img_path).convert('RGB')   
        # random crop (disabled below in favor of a fixed top-left crop)
        #i = np.random.randint(clean.size[0] - self.image_size[0])
        #j = np.random.randint(clean.size[1] - self.image_size[1])
        i=0
        j=0
        clean = clean.crop([i, j, i+self.image_size[0], j+self.image_size[1]])
        transform = tv.transforms.Compose([
            # convert to a tensor in [0, 1]
            tv.transforms.ToTensor(),
            # rescale to [-1, 1]
            tv.transforms.Normalize((.5, .5, .5), (.5, .5, .5))
            ])
        clean = transform(clean)
        
        noisy = clean + 2 / 255 * self.sigma * torch.randn(clean.shape)
        return noisy, clean
def myimshow(image, ax=plt):
    image = image.to('cpu').numpy()
    image = np.moveaxis(image, [0, 1, 2], [2, 0, 1])
    image = (image + 1) / 2
    image[image < 0] = 0
    image[image > 1] = 1
    h = ax.imshow(image)
    ax.axis('off')
    return h

Load the training and test sets.

train_set = NoisyBSDSDataset(dataset_root_dir)
test_set = NoisyBSDSDataset(dataset_root_dir, mode='test', image_size=(320, 320))
x = test_set[0]
fig, axes = plt.subplots(ncols=2)
myimshow(x[0], ax=axes[0])
axes[0].set_title('Noisy')
myimshow(x[1], ax=axes[1])
axes[1].set_title('Clean')
print(f'image size is {x[0].shape}.')
image size is torch.Size([3, 320, 320]).



[Figure: noisy and clean test image side by side]

IX. 3. DnCNN

The loss is mean squared error (MSE):

class NNRegressor(nt.NeuralNetwork):

    def __init__(self):
        super(NNRegressor, self).__init__()
        self.mse = nn.MSELoss()

    def criterion(self, y, d):
        return self.mse(y, d)

Why CNNs Need Weight Initialization

This write-up (in Traditional Chinese) explains it clearly:
深度學習: Weight initialization和Batch Normalization

Without weight initialization:

class DnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(DnCNN, self).__init__()
        self.D = D
        
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=1))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=1) for _ in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=1))
        
        self.bn = nn.ModuleList()
        for k in range(D):
            self.bn.append(nn.BatchNorm2d(C, C))

    def forward(self, x):
        D = self.D
        h = F.relu(self.conv[0](x))
        for i in range(D):
            h = F.relu(self.bn[i](self.conv[i+1](h)))
        y = self.conv[D+1](h) + x
        return y

Zero padding ('same' convolution) means the filters also act on the borders of the input image; what happens there is a bit of black magic.

x, _ = train_set[-1]
x = x.unsqueeze(0).to(device)
Ds = [0, 1, 2, 4, 8]

fig, axes = plt.subplots(nrows=len(Ds), ncols=3, figsize=(9,9))
for i in range(len(Ds)):
    with torch.no_grad():
        model = DnCNN(Ds[i]).to(device)
        y = model.forward(x) # 4-d
    # 3-d
    myimshow(x[0], ax=axes[i][0])
    axes[i][0].set_title('x[0]')
    myimshow(y[0], ax=axes[i][1])
    axes[i][1].set_title(f'y[0] (D={Ds[i]})')
    myimshow(x[0]-y[0], ax=axes[i][2])
    axes[i][2].set_title(f'x[0]-y[0] (D={Ds[i]})')

[Figure: x[0], y[0], and residual x[0]-y[0] for D = 0, 1, 2, 4, 8]

Only D = 0 produces a non-trivial residual; for deeper uninitialized nets the gradients vanish and training goes nowhere.

Now with weight initialization:

class DnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(DnCNN, self).__init__()
        self.D = D
        
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=1))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=1) for _ in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=1))
        # Kaiming normal initialization, a.k.a. He('s) initialization
        for i in range(len(self.conv[:-1])):
            nn.init.kaiming_normal_(self.conv[i].weight.data, nonlinearity='leaky_relu')
        
        # Batch norm
        self.bn = nn.ModuleList()
        self.bn.extend([nn.BatchNorm2d(C, C) for _ in range(D)])
        # initialize the batch norm scale (weight)
        for i in range(D):
            nn.init.constant_(self.bn[i].weight.data, 1.25 * np.sqrt(C))

    def forward(self, x):
        D = self.D
        h = F.leaky_relu(self.conv[0](x))
        for i in range(D):
            h = F.leaky_relu(self.bn[i](self.conv[i+1](h)))
        y = self.conv[D+1](h) + x
        return y
x, _ = train_set[-1]
x = x.unsqueeze(0).to(device)
Ds = [0, 1, 2, 4, 8]

fig, axes = plt.subplots(nrows=len(Ds), ncols=3, figsize=(9,9))
for i in range(len(Ds)):
    with torch.no_grad():
        model = DnCNN(Ds[i]).to(device)
        y = model.forward(x)
    
    myimshow(x[0], ax=axes[i][0])
    axes[i][0].set_title('x[0]')
    myimshow(y[0], ax=axes[i][1])
    axes[i][1].set_title(f'y[0] (D={Ds[i]})')
    myimshow(x[0]-y[0], ax=axes[i][2])
    axes[i][2].set_title(f'x[0]-y[0] (D={Ds[i]})')

[Figure: x[0], y[0], and residual x[0]-y[0] for D = 0, 1, 2, 4, 8]

Now the residuals are non-zero, gradients can flow, and the networks can actually be trained.
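
Worth noting (my own check, not from the notebook): kaiming_normal_ passes its default a=0 as the negative slope when nonlinearity='leaky_relu', so the computed gain is the same $\sqrt{2}$ as for plain ReLU; the switch changes the activation but not the initialization scale.

import torch.nn as nn

print(nn.init.calculate_gain('relu'))              # 1.4142...
print(nn.init.calculate_gain('leaky_relu', 0))     # 1.4142... (kaiming_normal_'s default a=0)
print(nn.init.calculate_gain('leaky_relu', 0.01))  # 1.4141... (F.leaky_relu's default slope)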

PSNR

Peak Signal-to-Noise Ratio (PSNR). The images here take values in $[-1, 1]$ (peak-to-peak range 2), which is where the 4 in the numerator comes from:

$$ PSNR = 10\log_{10}\frac{4n}{\Vert y-d\Vert_2^2} $$

$d$ is the clean target and $y$ the estimate; dividing the denominator by $n$ (the number of tensor elements) gives the mean squared error. Being log-scaled, PSNR is measured in dB, and higher is better.

It has to be averaged over training updates:

class DenoisingStatsManager(nt.StatsManager):

    def __init__(self):
        super(DenoisingStatsManager, self).__init__()

    def init(self):
        super(DenoisingStatsManager, self).init()
        self.running_psnr = 0

    def accumulate(self, loss, x, y, d):
        super(DenoisingStatsManager, self).accumulate(loss, x, y, d)    
        n = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3]
        self.running_psnr += 10*torch.log10(4*n/(torch.norm(y-d)**2))

    def summarize(self):
        loss = super(DenoisingStatsManager, self).summarize()
        psnr = self.running_psnr / self.number_update
        return {'loss': loss, 'PSNR': psnr.cpu()}
def plot(exp, fig, axes, noisy, visu_rate=2):
    if exp.epoch % visu_rate != 0:
        return
    with torch.no_grad():
        denoised = exp.net(noisy[None].to(net.device))[0]
    axes[0][0].clear()
    axes[0][1].clear()
    axes[1][0].clear()
    axes[1][1].clear()
    myimshow(noisy, ax=axes[0][0])
    axes[0][0].set_title('Noisy image')
    
    myimshow(denoised, ax=axes[0][1])
    axes[0][1].set_title('Denoised image')
    
    axes[1][0].plot([exp.history[k][0]['loss'] for k in range(exp.epoch)], label='training loss')
    axes[1][0].set_ylabel('Loss')
    axes[1][0].set_xlabel('Epoch')
    axes[1][0].legend()
    
    axes[1][1].plot([exp.history[k][0]['PSNR'] for k in range(exp.epoch)], label='training psnr')
    axes[1][1].set_ylabel('PSNR')
    axes[1][1].set_xlabel('Epoch')
    axes[1][1].legend()
    
    plt.tight_layout()
    fig.canvas.draw()

Training DnCNN

lr = 1e-3
net = DnCNN(6).to(device)
adam = torch.optim.Adam(net.parameters(), lr=lr)
stats_manager = DenoisingStatsManager()
exp1 = nt.Experiment(net, train_set, test_set, adam, stats_manager, batch_size=4, 
               output_dir="./checkpoints/denoising1", perform_validation_during_training=True)
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
exp1.run(num_epochs=200, plot=lambda exp: plot(exp, fig=fig, axes=axes,
                                                noisy=test_set[0][0]))
Start/Continue training from epoch 200
Finish training for 200 epochs



[Figure: noisy/denoised images and training loss/PSNR curves]

Results

img = []
model = exp1.net.to(device)
titles = ['clean', 'noise', 'DnCNN']

x, clean = test_set[0]
x = x.unsqueeze(0).to(device)
img.append(clean)
img.append(x[0])

model.eval()
with torch.no_grad():
    y = model.forward(x)
img.append(y[0])
    
fig, axes = plt.subplots(ncols=3, figsize=(20,10), sharex='all', sharey='all')
for i in range(len(img)):
    myimshow(img[i], ax=axes[i])
    axes[i].set_title(f'{titles[i]}')

[Figure: clean, noisy, and DnCNN-denoised images]

Some noise remains, and some detail has been lost.

DnCNN Network Parameters

for name, param in model.named_parameters():
    print(name, param.size(), param.requires_grad)
conv.0.weight torch.Size([64, 3, 3, 3]) True
conv.0.bias torch.Size([64]) True
conv.1.weight torch.Size([64, 64, 3, 3]) True
conv.1.bias torch.Size([64]) True
conv.2.weight torch.Size([64, 64, 3, 3]) True
conv.2.bias torch.Size([64]) True
conv.3.weight torch.Size([64, 64, 3, 3]) True
conv.3.bias torch.Size([64]) True
conv.4.weight torch.Size([64, 64, 3, 3]) True
conv.4.bias torch.Size([64]) True
conv.5.weight torch.Size([64, 64, 3, 3]) True
conv.5.bias torch.Size([64]) True
conv.6.weight torch.Size([64, 64, 3, 3]) True
conv.6.bias torch.Size([64]) True
conv.7.weight torch.Size([3, 64, 3, 3]) True
conv.7.bias torch.Size([3]) True
bn.0.weight torch.Size([64]) True
bn.0.bias torch.Size([64]) True
bn.1.weight torch.Size([64]) True
bn.1.bias torch.Size([64]) True
bn.2.weight torch.Size([64]) True
bn.2.bias torch.Size([64]) True
bn.3.weight torch.Size([64]) True
bn.3.bias torch.Size([64]) True
bn.4.weight torch.Size([64]) True
bn.4.bias torch.Size([64]) True
bn.5.weight torch.Size([64]) True
bn.5.bias torch.Size([64]) True

Parameter count

The first layer has 64 x 3 x 3 x 3 weights, each of the D middle layers has 64 x 64 x 3 x 3, and the last layer has 3 x 64 x 3 x 3. In total: 3456 + 36864 x D.

Receptive field calculation:

There is no pooling, so each layer adds a fixed $2^{0-0+1}=2$ to the side length; starting from 1 at the input, grade-school arithmetic gives $R_D=(1+2\times (D+2))^2$.

For D = 6, $R_6=17^2$.

Reportedly (citation needed), denoising σ = 30 Gaussian noise requires each output pixel to be influenced by 33 × 33 input pixels, and the depth is chosen accordingly.

Setting the receptive-field side $1+2\times (D+2)$ equal to 33 gives $D=14$, for $3456 + 36864 \times 14 = 519552$ parameters.

X. 4. UDnCNN

U-net like CNNs

class UDnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(UDnCNN, self).__init__()
        self.D = D
        
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=1))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=1) for _ in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=1))
        # Kaiming normal initialization, a.k.a. He('s) initialization
        for i in range(len(self.conv[:-1])):
            nn.init.kaiming_normal_(self.conv[i].weight.data, nonlinearity='leaky_relu')
        
        # batch norm
        self.bn = nn.ModuleList()
        self.bn.extend([nn.BatchNorm2d(C, C) for _ in range(D)])
        # initialize the batch norm scale (weight)
        for i in range(D):
            nn.init.constant_(self.bn[i].weight.data, 1.25 * np.sqrt(C))
    # everything above is unchanged; the forward pass becomes U-Net-like
    def forward(self, x):
        D = self.D
        h = F.leaky_relu(self.conv[0](x))
        h_buff = []
        idx_buff = []
        shape_buff = []
        for i in range(D//2-1):
            shape_buff.append(h.shape)
            h, idx = F.max_pool2d(F.leaky_relu(self.bn[i](self.conv[i+1](h))), 
                                  kernel_size=(2,2), return_indices=True)
            h_buff.append(h)
            idx_buff.append(idx)
        for i in range(D//2-1, D//2+1):
            h = F.leaky_relu(self.bn[i](self.conv[i+1](h)))
        for i in range(D//2+1, D):
            j = i - (D//2 + 1) + 1
            h = F.max_unpool2d(F.leaky_relu(self.bn[i](self.conv[i+1]((h+h_buff[-j])/np.sqrt(2)))), 
                               idx_buff[-j], kernel_size=(2,2), output_size=shape_buff[-j])
        y = self.conv[D+1](h) + x
        return y

Training UDnCNN

lr = 1e-3
net = UDnCNN(6).to(device)
adam = torch.optim.Adam(net.parameters(), lr=lr)
stats_manager = DenoisingStatsManager()
exp2 = nt.Experiment(net, train_set, test_set, adam, stats_manager, batch_size=4, 
               output_dir="/content/checkpoints/denoising2", perform_validation_during_training=True)
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
exp2.run(num_epochs=200, plot=lambda exp: plot(exp, fig=fig, axes=axes,
                                                noisy=test_set[0][0]))
Start/Continue training from epoch 50
Epoch 51 (Time: 3.56s)
Epoch 52 (Time: 3.37s)
...
Epoch 199 (Time: 3.40s)
Epoch 200 (Time: 3.37s)
Finish training for 200 epochs



[Figure: noisy/denoised images and training loss/PSNR curves]

UDnCNN Network Parameters

for name, param in exp2.net.named_parameters():
    print(name, param.size(), param.requires_grad)
conv.0.weight torch.Size([64, 3, 3, 3]) True
conv.0.bias torch.Size([64]) True
conv.1.weight torch.Size([64, 64, 3, 3]) True
conv.1.bias torch.Size([64]) True
conv.2.weight torch.Size([64, 64, 3, 3]) True
conv.2.bias torch.Size([64]) True
conv.3.weight torch.Size([64, 64, 3, 3]) True
conv.3.bias torch.Size([64]) True
conv.4.weight torch.Size([64, 64, 3, 3]) True
conv.4.bias torch.Size([64]) True
conv.5.weight torch.Size([64, 64, 3, 3]) True
conv.5.bias torch.Size([64]) True
conv.6.weight torch.Size([64, 64, 3, 3]) True
conv.6.bias torch.Size([64]) True
conv.7.weight torch.Size([3, 64, 3, 3]) True
conv.7.bias torch.Size([3]) True
bn.0.weight torch.Size([64]) True
bn.0.bias torch.Size([64]) True
bn.1.weight torch.Size([64]) True
bn.1.bias torch.Size([64]) True
bn.2.weight torch.Size([64]) True
bn.2.bias torch.Size([64]) True
bn.3.weight torch.Size([64]) True
bn.3.bias torch.Size([64]) True
bn.4.weight torch.Size([64]) True
bn.4.bias torch.Size([64]) True
bn.5.weight torch.Size([64]) True
bn.5.bias torch.Size([64]) True

Pooling does not change the parameter count, so it is still 3456 + 36864 x D.

Receptive field calculation:

$R_D=(1+\sum_{i=1}^{D/2}2^i+2\times 2^{D/2}+\sum_{i=1}^{D/2-1}2^i+2)^2$.

For D = 6: $R_6=(1+(2+4+8)+(2\times 8)+(4+2)+2)^2=39^2$.

Judging by PSNR, adding the U-Net structure actually makes things worse... pooling discards valuable feature information, even if the outputs arguably look "nicer"...

Comparing DnCNN and UDnCNN

# DnCNN
exp1.evaluate()
{'PSNR': tensor(29.0894), 'loss': 0.005089376987889409}



# UDnCNN
exp2.evaluate()
{'PSNR': tensor(28.3012), 'loss': 0.0060464405920356516}



img = []
titles = ['clean', 'noise', 'DnCNN','UDnCNN']

x, clean = test_set[0]
x = x.unsqueeze(0).to(device)
img.append(clean)
img.append(x[0])

model = exp1.net.to(device)
model.eval()
with torch.no_grad():
    y = model.forward(x)
img.append(y[0])

model = exp2.net.to(device)
model.eval()
with torch.no_grad():
    y = model.forward(x)
img.append(y[0])
    
fig, axes = plt.subplots(ncols=4, figsize=(20,10), sharex='all', sharey='all')
for i in range(len(img)):
    myimshow(img[i], ax=axes[i])
    axes[i].set_title(f'{titles[i]}')

[Figure: clean, noisy, DnCNN, and UDnCNN outputs]

XI. 5. DUDnCNN

U-net like CNNs with dilated convolutions

Dilated convolutions replace pooling as the way to enlarge the receptive field.

However, dilated convolutions run surprisingly slowly in PyTorch, even though in principle they should be roughly as fast as ordinary convolutions; it is an optimization issue...

Setting torch.backends.cudnn.benchmark=True before each dilated convolution and switching back to torch.backends.cudnn.benchmark=False afterwards gives a speed-up; see https://github.com/pytorch/pytorch/issues/15054 for details.

class DUDnCNN(NNRegressor):

    def __init__(self, D, C=64):
        super(DUDnCNN, self).__init__()
        self.D = D
        
        # compute k(max_pool) and l(max_unpool)
        k = [0]
        k.extend([i for i in range(D//2)])
        k.extend([k[-1] for _ in range(D//2, D+1)])
        l = [0 for _ in range(D//2+1)]
        l.extend([i for i in range(D+1-(D//2+1))])
        l.append(l[-1])
        
        # dilation schedule for the dilated convolutions
        holes = [2**(kl[0]-kl[1])-1 for kl in zip(k,l)]
        dilations = [i+1 for i in holes]
        
        # conv layers
        self.conv = nn.ModuleList()
        self.conv.append(nn.Conv2d(3, C, 3, padding=dilations[0], dilation=dilations[0]))
        self.conv.extend([nn.Conv2d(C, C, 3, padding=dilations[i+1], dilation=dilations[i+1]) for i in range(D)])
        self.conv.append(nn.Conv2d(C, 3, 3, padding=dilations[-1], dilation=dilations[-1]))
        # Kaiming normal initialization, a.k.a. He('s) initialization
        for i in range(len(self.conv[:-1])):
            nn.init.kaiming_normal_(self.conv[i].weight.data, nonlinearity='leaky_relu')
        
        # batch norm
        self.bn = nn.ModuleList()
        self.bn.extend([nn.BatchNorm2d(C, C) for _ in range(D)])
        # initialize the batch norm scale (weight)
        for i in range(D):
            nn.init.constant_(self.bn[i].weight.data, 1.25 * np.sqrt(C))

    def forward(self, x):
        D = self.D
        h = F.leaky_relu(self.conv[0](x))
        h_buff = []

        for i in range(D//2 - 1):
            torch.backends.cudnn.benchmark = True
            h = self.conv[i+1](h)
            torch.backends.cudnn.benchmark = False
            h = F.leaky_relu(self.bn[i](h))
            h_buff.append(h)
            
        for i in range(D//2 - 1, D//2 + 1):
            torch.backends.cudnn.benchmark = True
            h = self.conv[i+1](h)
            torch.backends.cudnn.benchmark = False
            h = F.leaky_relu(self.bn[i](h))
            
        for i in range(D//2 + 1, D):
            j = i - (D//2 + 1) + 1
            torch.backends.cudnn.benchmark = True
            h = self.conv[i+1]((h + h_buff[-j]) / np.sqrt(2))
            torch.backends.cudnn.benchmark = False
            h = F.leaky_relu(self.bn[i](h))
            
        y = self.conv[D+1](h) + x
        return y

Training DUDnCNN

lr = 1e-3
net = DUDnCNN(6).to(device)
adam = torch.optim.Adam(net.parameters(), lr=lr)
stats_manager = DenoisingStatsManager()
exp3 = nt.Experiment(net, train_set, test_set, adam, stats_manager, batch_size=4, 
               output_dir="./checkpoints/denoising3", perform_validation_during_training=True)
exp3
Net(DUDnCNN(
  (mse): MSELoss()
  (conv): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (bn): ModuleList(
    (0): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (5): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
  )
))
TrainSet(NoisyBSDSDataset(mode=train, image_size=(180, 180), sigma=30))
ValSet(NoisyBSDSDataset(mode=test, image_size=(320, 320), sigma=30))
Optimizer(Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
))
StatsManager(DenoisingStatsManager)
BatchSize(4)
PerformValidationDuringTraining(True)



fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(9, 7))
exp3.run(num_epochs=200, plot=lambda exp: plot(exp, fig=fig, axes=axes,
                                                noisy=test_set[0][0]))
Start/Continue training from epoch 50
Epoch 51 (Time: 6.67s)
Epoch 52 (Time: 6.41s)
Epoch 53 (Time: 6.42s)
Epoch 54 (Time: 6.46s)
Epoch 55 (Time: 6.48s)
...
Epoch 200 (Time: 6.69s)
Finish training for 200 epochs



(figure: training loss/PSNR curves and a denoised preview, updated during training)

Comparing DnCNN, UDnCNN, and DUDnCNN

# DnCNN
exp1.evaluate()
{'PSNR': tensor(29.0859), 'loss': 0.00509307918138802}



# UDnCNN
exp2.evaluate()
{'PSNR': tensor(28.3015), 'loss': 0.0060443114023655655}



# DUDnCNN
exp3.evaluate()
{'PSNR': tensor(29.1659), 'loss': 0.005009378213435412}
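
As a sanity check on these numbers (a sketch, assuming the dataset normalizes images to $[-1, 1]$, i.e. a peak-to-peak range of 2, so that PSNR $= 10\log_{10}(2^2/\mathrm{MSE})$):

import numpy as np
mse = 0.005009378213435412      # DUDnCNN test loss from above
print(10 * np.log10(4 / mse))   # ~29.02 dB, close to the reported 29.17

The small gap is expected if PSNR is averaged per batch rather than computed from the average MSE.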



num = 3
img = []
nets = [exp1.net, exp2.net, exp3.net]
titles = ['noisy', 'DnCNN', 'UDnCNN', 'DUDnCNN']

fig, axes = plt.subplots(nrows=num, ncols=4, figsize=(20, 15), sharex='all', sharey='all')

# first column: noisy test images (also kept as network inputs)
for i in range(num):
    myimshow(test_set[7*i+7][0], ax=axes[i][0])
    x, _ = test_set[7*i+7]
    x = x.unsqueeze(0).to(device)
    img.append(x)

# remaining columns: each model's denoised output
for i in range(num):
    for j in range(len(nets)):
        model = nets[j].to(device)
        model.eval()
        with torch.no_grad():
            y = model(img[i])
        myimshow(y[0], ax=axes[i][j+1])

# column titles
for i in range(num):
    for j in range(len(titles)):
        axes[i][j].set_title(f'{titles[j]}')

(figure: noisy inputs and DnCNN / UDnCNN / DUDnCNN outputs for three test images)

DUDnCNN Network Parameters

exp3.net
DUDnCNN(
  (mse): MSELoss()
  (conv): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (bn): ModuleList(
    (0): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
    (5): BatchNorm2d(64, eps=64, momentum=0.1, affine=True, track_running_stats=True)
  )
)



for name, param in exp3.net.named_parameters():
    print(name, param.size(), param.requires_grad)
conv.0.weight torch.Size([64, 3, 3, 3]) True
conv.0.bias torch.Size([64]) True
conv.1.weight torch.Size([64, 64, 3, 3]) True
conv.1.bias torch.Size([64]) True
conv.2.weight torch.Size([64, 64, 3, 3]) True
conv.2.bias torch.Size([64]) True
conv.3.weight torch.Size([64, 64, 3, 3]) True
conv.3.bias torch.Size([64]) True
conv.4.weight torch.Size([64, 64, 3, 3]) True
conv.4.bias torch.Size([64]) True
conv.5.weight torch.Size([64, 64, 3, 3]) True
conv.5.bias torch.Size([64]) True
conv.6.weight torch.Size([64, 64, 3, 3]) True
conv.6.bias torch.Size([64]) True
conv.7.weight torch.Size([3, 64, 3, 3]) True
conv.7.bias torch.Size([3]) True
bn.0.weight torch.Size([64]) True
bn.0.bias torch.Size([64]) True
bn.1.weight torch.Size([64]) True
bn.1.bias torch.Size([64]) True
bn.2.weight torch.Size([64]) True
bn.2.bias torch.Size([64]) True
bn.3.weight torch.Size([64]) True
bn.3.bias torch.Size([64]) True
bn.4.weight torch.Size([64]) True
bn.4.bias torch.Size([64]) True
bn.5.weight torch.Size([64]) True
bn.5.bias torch.Size([64]) True

The number of parameters is still the same as before: $3456 + 36864 \times D$ convolution weights (biases and batch-norm parameters not counted); see the check below.
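
This can be verified directly from the named parameters (a minimal sketch; it counts convolution weights only):

n_conv = sum(p.numel() for name, p in exp3.net.named_parameters()
             if name.startswith('conv') and name.endswith('weight'))
print(n_conv, 3456 + 36864 * 6)  # 224640 224640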

Receptive field computation:

$R_D=(1+2+\sum_{i=1}^{D/2}2^i\times 2^{i-1}+2^{D/2}\times 2^{D/2-1}+\sum_{i=1}^{D/2-1}2^i\times 2^{i-1}+2)^2$.

$R_6=89^2$.
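
Plugging $D = 6$ into the formula can be checked with a few lines of Python (a quick sketch):

D = 6
R = (1 + 2
     + sum(2**i * 2**(i-1) for i in range(1, D//2 + 1))
     + 2**(D//2) * 2**(D//2 - 1)
     + sum(2**i * 2**(i-1) for i in range(1, D//2))
     + 2)
print(R)  # 89, so R_6 = 89**2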
