基于字典學習的圖像去噪?

一個列子比較了浣熊人臉圖像噪聲碎片重構效果的比較, 首先使用在線詞典學習和各種轉換方法。

字典用來擬合圖像的左半部,然后用來重建右半部分。請注意,更好的性能可以通過擬合一個不失真(即無噪音)圖像來實現,但在這里,我們從假設它是不可用的。

評價圖像去噪效果的一個常見方法是通過觀察重建圖像與原始圖像的差異來評價圖像去噪效果。如果重建是完美的,這將看起來像高斯噪聲。

從圖中可以看出,具有兩個非零系數的正交匹配追蹤(OMP)的結果比只保持一個(邊界看起來不那么突出)的結果有一點偏差。

最小角回歸的結果具有更強的偏差:這種差異使人聯想到原始圖像的局部強度值。

閾值處理顯然對去噪沒有幫助,但在這里表明,它能夠以非常高的速度產生暗示性的輸出,因此對其他任務(如目標分類)非常有用,在這些任務中,性能不一定與可視化有關。

Distorting image...
Extracting reference patches...
done in 0.01s.
Learning the dictionary...
done in 3.70s.
Extracting noisy patches...
done in 0.00s.
Orthogonal Matching Pursuit
1 atom...
done in 0.96s.
Orthogonal Matching Pursuit
2 atoms...
done in 2.18s.
Least-angle regression
5 atoms...
done in 21.14s.
Thresholding
 alpha=0.1...
done in 0.14s.
print(__doc__)

from time import time

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.feature_extraction.image import reconstruct_from_patches_2d


try:  # SciPy >= 0.16 have face in misc
    from scipy.misc import face
    face = face(gray=True)
except ImportError:
    face = sp.face(gray=True)

# Convert from uint8 representation with values between 0 and 255 to
# a floating point representation with values between 0 and 1.
face = face / 255.

# downsample for higher speed
face = face[::4, ::4] + face[1::4, ::4] + face[::41::4] + face[1::41::4]
face /= 4.0
height, width = face.shape

# Distort the right half of the image
print('Distorting image...')
distorted = face.copy()
distorted[:, width // 2:] += 0.075 * np.random.randn(height, width // 2)

# Extract all reference patches from the left half of the image
print('Extracting reference patches...')
t0 = time()
patch_size = (77)
data = extract_patches_2d(distorted[:, :width // 2], patch_size)
data = data.reshape(data.shape[0], -1)
data -= np.mean(data, axis=0)
data /= np.std(data, axis=0)
print('done in %.2fs.' % (time() - t0))

# #############################################################################
# Learn the dictionary from reference patches

print('Learning the dictionary...')
t0 = time()
dico = MiniBatchDictionaryLearning(n_components=100, alpha=1, n_iter=500)
V = dico.fit(data).components_
dt = time() - t0
print('done in %.2fs.' % dt)

plt.figure(figsize=(4.24))
for i, comp in enumerate(V[:100]):
    plt.subplot(1010, i + 1)
    plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
plt.suptitle('Dictionary learned from face patches\n' +
             'Train time %.1fs on %d patches' % (dt, len(data)),
             fontsize=16)
plt.subplots_adjust(0.080.020.920.850.080.23)


# #############################################################################
# Display the distorted image

def show_with_diff(image, reference, title):
    """Helper function to display denoising"""
    plt.figure(figsize=(53.3))
    plt.subplot(121)
    plt.title('Image')
    plt.imshow(image, vmin=0, vmax=1, cmap=plt.cm.gray,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
    plt.subplot(122)
    difference = image - reference

    plt.title('Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))
    plt.imshow(difference, vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
    plt.suptitle(title, size=16)
    plt.subplots_adjust(0.020.020.980.790.020.2)

show_with_diff(distorted, face, 'Distorted image')

# #############################################################################
# Extract noisy patches and reconstruct them using the dictionary

print('Extracting noisy patches... ')
t0 = time()
data = extract_patches_2d(distorted[:, width // 2:], patch_size)
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data -= intercept
print('done in %.2fs.' % (time() - t0))

transform_algorithms = [
    ('Orthogonal Matching Pursuit\n1 atom''omp',
     {'transform_n_nonzero_coefs'1}),
    ('Orthogonal Matching Pursuit\n2 atoms''omp',
     {'transform_n_nonzero_coefs'2}),
    ('Least-angle regression\n5 atoms''lars',
     {'transform_n_nonzero_coefs'5}),
    ('Thresholding\n alpha=0.1''threshold', {'transform_alpha'.1})]

reconstructions = {}
for title, transform_algorithm, kwargs in transform_algorithms:
    print(title + '...')
    reconstructions[title] = face.copy()
    t0 = time()
    dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
    code = dico.transform(data)
    patches = np.dot(code, V)

    patches += intercept
    patches = patches.reshape(len(data), *patch_size)
    if transform_algorithm == 'threshold':
        patches -= patches.min()
        patches /= patches.max()
    reconstructions[title][:, width // 2:] = reconstruct_from_patches_2d(
        patches, (height, width // 2))
    dt = time() - t0
    print('done in %.2fs.' % dt)
    show_with_diff(reconstructions[title], face,
                   title + ' (time: %.1fs)' % dt)

plt.show()

腳本的總運行時間:(0分30.682秒)

Download Python source code: plot_image_denoising.py

Download Jupyter notebook: plot_image_denoising.ipynb