Note
Click here to download the full example code
04. Total Generalized Variation
We take a step deeper into total-variation-based regularization.
We draw on concepts from several papers. For the numerical implementation we mainly follow:
Knoll, Bredies, Pock, Stollberger: Second Order Total Generalized Variation (TGV) for MRI
First-order total variation (TV) suffers from staircasing artifacts, as the following TV-regularized denoising example shows.
import numpy as np
import matplotlib.pyplot as plt
from recon.utils import psnr
from recon.utils.images import two_smooth_squares
from recon.interfaces import Smoothing, SmoothBregman
# Build the test image and corrupt it with zero-mean Gaussian noise whose
# standard deviation is 20 % of the image maximum.
image = two_smooth_squares(256, 128)
noise_sigma = 0.2 * np.max(image)
noise_image = image + np.random.normal(0, noise_sigma, size=image.shape)

# First-order TV denoising through the Smoothing interface.
tv_denoising = Smoothing(
    domain_shape=image.shape,
    reg_mode='tv',
    lam=0.3,
    alpha=0.1,
    tau='calc',
)
tv_solution = tv_denoising.solve(noise_image, max_iter=2000, tol=1e-4)

# Show the noisy input next to the TV result on a shared grey scale.
f = plt.figure(figsize=(6, 3))
tv_panels = (
    ("Noisy", noise_image),
    ("TV based denoising", tv_solution),
)
for column, (caption, picture) in enumerate(tv_panels, start=1):
    f.add_subplot(1, len(tv_panels), column)
    plt.gray()
    plt.imshow(picture, vmin=0, vmax=np.max(image))
    plt.title(caption)
    plt.axis('off')
plt.show()
Out:
Early stopping.
To avoid the strong staircasing effects, we introduce total generalized variation (TGV). At this point there is no interface for second-order TV, so we implement it directly with an adapted primal-dual algorithm.
from recon.solver.pd_hgm_tgv import PdHgmTGV
# TGV smoothing with a small alpha pair (presumably the two TGV
# regularisation weights -- confirm against PdHgmTGV).
alpha = (0.3, 0.6)
solver = PdHgmTGV(alpha=alpha, lam=0.9)

# The solver output is reshaped back to the image grid before plotting.
tgv_solution = np.reshape(solver.solve(noise_image), image.shape)

# Original, TV result and TGV result side by side on a shared grey scale.
f = plt.figure(figsize=(9, 3))
tgv_panels = (
    ("Original", image),
    ("TV based denoising", tv_solution),
    ("TGV based denoising", tgv_solution),
)
for column, (caption, picture) in enumerate(tgv_panels, start=1):
    f.add_subplot(1, len(tgv_panels), column)
    plt.gray()
    plt.imshow(picture, vmin=0, vmax=np.max(image))
    plt.title(caption)
    plt.axis('off')
plt.show()
Out:
0.000805638699629
0.000407663334895
0.000267929634266
0.000221235935128
0.000170297437531
0.000140454748164
0.000123419405093
9.56346020134e-05
Since TGV is also a convex functional, it can likewise be extended with Bregman iterations. Maybe there will be an interface for this in the future.
plot_iteration = False
lam = 0.3

# Stopping threshold: 0.2 * max(image) (the standard deviation of the noise
# added to the image) times the square root of the number of pixels.
assessment = 0.2 * np.max(image) * np.sqrt(np.prod(noise_image.shape))

# Flat, zero-initialised pk handed to the solver (presumably the Bregman
# subgradient -- confirm against PdHgmTGV).
pk = np.zeros(image.shape).ravel()

noisy_flat = noise_image.ravel()
i = 0
u = np.zeros(image.shape)
while True:
    print("current norm error: " + str(np.linalg.norm(u.ravel() - noisy_flat, 2)))
    print("runs till norm <: " + str(assessment))

    # One TGV solve with the current pk.
    solver = PdHgmTGV(alpha=alpha, lam=lam, mode='tgv', pk=pk)
    u_new = np.reshape(solver.solve(noise_image), image.shape)

    # Leave the loop as soon as the data residual falls below the threshold;
    # u_new then holds the final result.
    residual = u_new.ravel() - noisy_flat
    if np.linalg.norm(residual, 2) < assessment:
        break

    # Accept the iterate and update pk with the scaled residual.
    u = u_new
    pk = pk - lam / alpha[0] * residual
    i += 1

    # Optionally dump the intermediate iterate to a PNG file.
    if plot_iteration:
        plt.gray()
        plt.imshow(u)
        plt.axis('off')
        plt.savefig('Bregman_TGV_iter' + str(i) + '.png', bbox_inches='tight', pad_inches=0)
        plt.close()
Out:
current norm error: 165.813232939
runs till norm <: 51.2
0.00100149905123
0.000345289902751
0.000209819486021
0.000124473937056
0.000103793210188
8.1683419529e-05
current norm error: 52.6711913506
runs till norm <: 51.2
0.00105654291953
0.000508841076177
0.000343839388152
0.000198076984105
0.000191136825702
0.000145567166102
0.000120249541513
9.37490182809e-05
Compare it to standard Bregman TV (BTV).
# Bregman-iterated TV denoising through the SmoothBregman interface, using
# the same stopping threshold as the Bregman-TGV loop.
breg_smoothing = SmoothBregman(
    domain_shape=image.shape,
    reg_mode='tv',
    alpha=1,
    lam=0.5,
    tau='calc',
    plot_iteration=False,
    assessment=assessment,
)
u_breg = breg_smoothing.solve(data=noise_image, max_iter=2000, tol=1e-4)

# Original, Bregman-TV and Bregman-TGV results side by side; the two
# reconstructions are reshaped to the image grid for display.
f = plt.figure(figsize=(9, 3))
bregman_panels = (
    ("Original", image),
    ("BTV ", np.reshape(u_breg, image.shape)),
    ("BTGV", np.reshape(u_new, image.shape)),
)
for column, (caption, picture) in enumerate(bregman_panels, start=1):
    f.add_subplot(1, len(bregman_panels), column)
    plt.gray()
    plt.imshow(picture, vmin=0, vmax=np.max(image))
    plt.title(caption)
    plt.axis('off')
plt.show()

# PSNR of every reconstruction against the noise-free image.
psnr_report = (
    ("TV-PSNR: ", tv_solution),
    ("TGV-PSNR: ", tgv_solution),
    ("BTV-PSNR: ", u_breg),
    ("BTGV-PSNR: ", u_new),
)
for label, reconstruction in psnr_report:
    print(label + str(psnr(image, reconstruction)))
Out:
current norm error: 165.813232939
runs till norm <: 51.2
Early stopping.
current norm error: 54.0902926475
runs till norm <: 51.2
Early stopping.
TV-PSNR: 32.86
TGV-PSNR: 33.42
BTV-PSNR: 31.33
BTGV-PSNR: 36.89
Total running time of the script: ( 2 minutes 46.552 seconds)