03. Reconstruction II

Convolution

IN PROGRESS

Degraded, Minimum-Norm, TV-Recon

Out:

/Users/lucasplagwitz/git_projects/recon/tutorials/convolution.py:89: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  norm = np.abs(np.asscalar((K.H*K).eigs(neigs=1, symmetric=True, largest=True, uselobpcg=True)))
0.350540804411
 Early stopping.
1.0
0.000191388322656
0.000122128998729
PSNR Noise: 22.87
PSNR Minimum-Norm: 18.84
PSNR TV-Recon: 26.34
MAE Noise: 0.0397541860084
MAE Minimum-Norm: 0.101340461986
MAE TV-Recon: 0.0264372818344

from pylops import Gradient

from recon.terms import BaseDataterm, IndicatorL2
from recon.solver import PdHgm
import matplotlib.pyplot as plt
from typing import Union
import skimage.data as skd
from scipy.fftpack import fft2, ifft2
from recon.utils import psnr

import numpy as np


def rgb2gray(rgb):
    """Collapse an RGB image to a single gray channel.

    The first three entries along the last axis of a 3-D array are read as
    red, green and blue and combined as the weighted sum
    ``0.2989 * R + 0.5870 * G + 0.1140 * B``.

    Parameters
    ----------
    rgb : array indexable as ``rgb[:, :, c]`` for ``c`` in 0, 1, 2.

    Returns
    -------
    The weighted sum, with the last axis removed.
    """
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue


# Ground truth: gray-scale version of the scikit-image "coffee" picture,
# restricted to columns 80..480 and rescaled so that its maximum is 1.
gt = rgb2gray(skd.coffee())[:, 80:481]
gt = gt / gt.max()
sh = gt.shape
a = gt

# Horizontal motion-blur kernel: a single row of ones in a size x size
# window, normalised so that the kernel sums to 1.
size = 20
kernel_motion_blur = np.zeros((size, size))
kernel_motion_blur[(size - 1) // 2, :] = 1.0
kernel_motion_blur /= size
h = kernel = kernel_motion_blur


def fftconvolve2d(x, y):
    """Convolve ``x`` with ``y`` through a product of 2-D FFTs.

    ``y`` is transformed at the shape of ``x`` (the code assumes ``y`` is the
    "smaller" array), both spectra are multiplied, and the real part of the
    inverse transform is rolled back by ``(y.shape[k] - 1) // 2`` along each
    axis so that the kernel acts as if it were centred.

    Parameters
    ----------
    x : 2-D array, the image.
    y : 2-D array, the kernel.

    Returns
    -------
    Real array with the shape of ``x``.
    """
    target_shape = x.shape
    spectrum = fft2(x, shape=target_shape) * fft2(y, shape=target_shape)
    convolved = ifft2(spectrum).real

    # Undo the offset introduced by the kernel being anchored at index (0, 0).
    row_shift = -((y.shape[0] - 1) // 2)
    col_shift = -((y.shape[1] - 1) // 2)
    return np.roll(convolved, (row_shift, col_shift), axis=(0, 1))


# Degraded observation: blurred ground truth plus Gaussian noise
# (mean 0, standard deviation 0.01).
f2 = fftconvolve2d(a, kernel)
f2 += np.random.normal(0, 0.01, size=f2.shape)

# Fourier transform of the observation after re-applying the shift that
# fftconvolve2d removed; this is the data spectrum used by the proximal step.
back = fft2(
    np.roll(
        f2,
        ((kernel.shape[0] - 1) // 2, (kernel.shape[1] - 1) // 2),
        axis=(0, 1),
    ),
    shape=f2.shape,
)


class DatanormL2Conv(BaseDataterm):
    """L2 data term for a convolution forward model, evaluated in Fourier space.

    All quantities are derived from the constructor arguments (``image_size``,
    ``cop``, ``data``) rather than from module-level variables, so the term
    can be built for any image/kernel/observation triple.

    Parameters
    ----------
    image_size : shape of the 2-D image the term acts on.
    cop : 2-D convolution kernel.
    data : observed (degraded) image; a scalar is broadcast to ``image_size``,
        an array is reshaped to ``image_size``.
    lam : weight of the data term.
    prox_param : step size used in the proximal operator.
    sampling : forwarded unchanged to ``BaseDataterm``.
    """

    def __init__(self,
                 image_size,
                 cop,
                 data: Union[float, np.ndarray] = 0,
                 lam: float = 1,
                 prox_param: float = 0.9,
                 sampling=None):
        image_size = tuple(image_size)
        cop = np.asarray(cop)

        operator = lambda x: fft2(x, shape=image_size)
        super(DatanormL2Conv, self).__init__(operator, sampling=sampling, prox_param=prox_param)
        self.lam = lam
        self.data = data
        self._image_shape = image_size
        # Kernel spectrum at image resolution.
        self.f_data = fft2(cop, shape=image_size)
        # NOTE(review): same value as f_data and not used by prox; kept so
        # existing attribute access keeps working.
        self.f_datah = fft2(cop, shape=image_size)
        self.inv_operator = lambda x: ifft2(x).real
        # Number of prox evaluations so far.
        self.i = 0

        # Spectrum of the observation. The roll re-applies the centring offset
        # (kernel extent - 1) // 2 per axis, i.e. the shift that
        # fftconvolve2d removes from its output, so that the observation is
        # aligned with the un-centred kernel spectrum above.
        observed = np.asarray(data)
        if observed.ndim == 0:
            observed = np.broadcast_to(observed, image_size)
        else:
            observed = np.reshape(observed, image_size)
        shift = tuple((extent - 1) // 2 for extent in cop.shape)
        self._f_observed = fft2(np.roll(observed, shift, axis=(0, 1)), shape=image_size)

    def prox(self, x):
        """Proximal Operator.

        Computes, with ``t = prox_param * lam``::

            ifft2( (fft2(x) + t * F_obs * conj(F_k))
                   / (1 + t * diag_sampling * F_k * conj(F_k)) ).real

        where ``F_k`` is the kernel spectrum and ``F_obs`` the spectrum of the
        observation. ``diag_sampling`` is expected to be provided by
        ``BaseDataterm`` (not visible in this file).

        Parameters
        ----------
        x : array with ``prod(image_size)`` entries.

        Returns
        -------
        Flattened real array of the same size.
        """
        self.i += 1
        step = self.prox_param * self.lam
        kernel_conj = self.f_data.conjugate()

        numerator = (self.operator(np.reshape(x, self._image_shape))
                     + step * self._f_observed * kernel_conj)
        denominator = 1 + step * self.diag_sampling * self.f_data * kernel_conj
        u = self.inv_operator(numerator / denominator)

        return u.ravel()


# TV reconstruction with the primal-dual solver.
K = Gradient(gt.shape, edge=True, dtype='float64', kind='backward', sampling=1)
# Largest eigenvalue of K^H K, used to derive the step size below.
# ndarray.item() replaces np.asscalar, which is deprecated since NumPy 1.16
# and no longer available in recent NumPy releases.
norm = np.abs((K.H*K).eigs(neigs=1, symmetric=True, largest=True, uselobpcg=True).item())
fac = 0.99
tau = fac * np.sqrt(1 / norm)
print(tau)
G = DatanormL2Conv(image_size=gt.shape, cop=kernel, data=f2, prox_param=tau, lam=100)
F_star = IndicatorL2(gt.shape, len(sh), prox_param=tau, upper_bound=1)

solver = PdHgm(K, F_star, G)
solver.max_iter = 3000
solver.tol = 1e-4
# Keep only the real part of the solver output and restore the image shape.
c = np.real(solver.solve())
x_tv = np.reshape(c, gt.shape)


# Proximal point algorithm for minimum norm solution.
# The data-term prox is applied repeatedly, starting from the zero image.
# Convergence is only tested on every 500th iteration: the loop stops once
# the relative change between consecutive iterates drops below 1e-4,
# otherwise the current relative change is printed.
x_old = np.zeros(gt.shape)
G = DatanormL2Conv(image_size=gt.shape, cop=kernel, data=f2, prox_param=tau, lam=1)
i = 0
while True:
    x_new = G.prox(x_old.ravel()).reshape(gt.shape)
    if i % 500 == 0:
        rel_change = (np.linalg.norm(x_old.ravel() - x_new.ravel())
                      / np.linalg.norm(x_new.ravel()))
        if rel_change < 1e-4:
            break
        print(rel_change)
    x_old = x_new
    i += 1

# Side-by-side comparison of the degraded image and both reconstructions.
# The minimum-norm panel is drawn without a fixed colour range; the other
# two are clipped to [0, 1].
fig, axs = plt.subplots(1, 3, figsize=(14, 5))
panels = (
    (np.reshape(f2, gt.shape), 'Degraded', {'vmin': 0, 'vmax': 1}),
    (x_new, 'Minimum-Norm', {}),
    (x_tv, "TV-Recon", {'vmin': 0, 'vmax': 1}),
)
for ax, (image, title, imshow_kwargs) in zip(axs, panels):
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')
fig.tight_layout()
plt.show()

# Quality metrics, all measured against the ground truth ``gt``.
print("PSNR Noise: " + str(psnr(gt.ravel(), f2.ravel())))
print("PSNR Minimum-Norm: " + str(psnr(gt.ravel(), x_new.ravel())))
print("PSNR TV-Recon: " + str(psnr(gt.ravel(), x_tv.ravel())))
# Mean absolute error. The reconstructions are compared with the ground
# truth (as the PSNR lines do), not with the degraded observation f2.
print("MAE Noise: " + str(np.sum(np.abs(f2-gt))/np.prod(gt.shape)))
print("MAE Minimum-Norm: " + str(np.sum(np.abs(x_new-gt))/np.prod(gt.shape)))
print("MAE TV-Recon: " + str(np.sum(np.abs(x_tv-gt))/np.prod(gt.shape)))

Total running time of the script: ( 2 minutes 10.589 seconds)

Gallery generated by Sphinx-Gallery