阅读量:0
如果用OpenCV-Python进行图像的离散傅里叶变换与逆变换其实还蛮简单的,流程就是上图所示,值得注意的是,如果是多通道的图像,譬如多光谱、高光谱图像,需要对每个通道都进行傅里叶变换,最后再聚合,如果只是RGB,可以用如下方式合成灰度图,只需要对灰度图做处理即可。
img1 = 0.2126 * image1[:,:,0] + 0.7152 * image1[:,:,1] + 0.0722 * image1[:,:,2]
import cv2 as cv import numpy as np import matplotlib.pyplot as plt # 测试图像 ori=cv.imread(r"F:\ori.jpg") # numpy 中的 fft 需要输入灰度图,我们需要将图像分割成不同的通道 def getRGBDFT(img): # cv2默认的图像通道是BGR,需要进行转换 img=cv.cvtColor(img,cv.COLOR_BGR2RGB) # 分离通道 r,g,b=cv.split(img) # 对每个通道进行傅里叶变换 f_r,f_g,f_b=_dft(r),_dft(g),_dft(b) # 组合通道,还是以bgr格式返回 return cv.merge([f_b,f_g,f_r]) def _dft(img): f=cv.dft(np.float32(img),flags=cv.DFT_COMPLEX_OUTPUT) # 计算幅度谱 magnitude=cv.magnitude(f[:,:,0],f[:,:,1]) # 对数变换增强对比度 res=np.log(magnitude+1) # 移动低频分量至中心 return np.fft.fftshift(res) out=getRGBDFT(ori) # 选取 plt.subplot(121),plt.imshow(ori[:,:,0],cmap='gray') plt.title("Ori"),plt.xticks([]),plt.yticks([]) plt.subplot(122),plt.imshow(out[:,:,0],cmap='gray') plt.title("Magnitude"),plt.xticks([]),plt.yticks([]) plt.show()
我们查看B通道的图像与傅里叶幅度谱:
接下来要进行傅里叶逆变换,代码如下:
def _idft(img): img=np.fft.ifftshift(img) img=cv.idft(img) return cv.magnitude(img[:,:,0],img[:,:,1])
若要在频域上做处理,可以添加掩膜:
def _Mask(img,type=None,d=2,size=4): row,col=img.shape[:-1] if type==None: return np.ones((row,col,d),np.uint8) if type == "LPF": mask=np.zeros((row,col,d),np.uint8) mask[row//size:row//size*(size-1),col//size:col//size*(size-1)]=1 elif type=="HPF": mask = np.ones((row, col,d), np.uint8) mask[row//size:row//size*(size-1),col//size:col//size*(size-1)] = 0 else: mask=np.ones((row,col,d),np.uint8) mask[row // 2-30:row // 2+30, col // 2-30:col // 2+30] = 0 return mask def _idft(img,mask=None): # img=_Mask(img,type)*img if mask!=None: img=mask*np.fft.ifftshift(img) else: img=np.fft.ifftshift(img) img=cv.idft(img) return cv.magnitude(img[:,:,0],img[:,:,1])
彩色图像结果:
如果想要用A图的高频细节替换B图,可以如下处理:
def _normal(img): a,b=img.max(),img.min() return np.clip((img-b)/(a-b),0,1) def swap(ori,aug): oriFFT=getRGBDFT(ori) augFFT=getRGBDFT(aug) HPF=_Mask(ori,"HPF",size=3) LPF=_Mask(ori,"LPF",size=16) res=[augFFT[i]*LPF+oriFFT[i]*HPF for i in range(len(oriFFT))] res=[_normal(_idft(i)) for i in res] return cv.merge(res)
完整代码如下:
import cv2 as cv import numpy as np import matplotlib.pyplot as plt # 测试图像 ori=cv.imread(r"F:\ori.jpg") aug=cv.imread(r"F:\129.jpg") # numpy 中的 fft 需要输入灰度图,我们需要将图像分割成不同的通道 def getRGBDFT(img): # cv2默认的图像通道是BGR,需要进行转换 img=cv.cvtColor(img,cv.COLOR_BGR2RGB) # 分离通道 r,g,b=cv.split(img) # 对每个通道进行傅里叶变换 f_r,f_g,f_b=_dft(r,False),_dft(g,False),_dft(b,False) # 组合通道,还是以bgr格式返回 # return cv.merge([f_b,f_g,f_r]) return [f_b,f_g,f_r] def _dft(img,to_show=True): f=cv.dft(np.float32(img),flags=cv.DFT_COMPLEX_OUTPUT) # 计算幅度谱 if to_show: # 对数变换增强对比度 magnitude = cv.magnitude(f[:, :, 0], f[:, :, 1]) f=np.log(magnitude+1) # 移动低频分量至中心 return np.fft.fftshift(f) def _idft(img,mask=None): # img=_Mask(img,type)*img if mask!=None: img=mask*np.fft.ifftshift(img) else: img=np.fft.ifftshift(img) img=cv.idft(img) return cv.magnitude(img[:,:,0],img[:,:,1]) def _Mask(img,type=None,d=2,size=4): row,col=img.shape[:-1] if type==None: return np.ones((row,col,d),np.uint8) if type == "LPF": mask=np.zeros((row,col,d),np.uint8) mask[row//size:row//size*(size-1),col//size:col//size*(size-1)]=1 elif type=="HPF": mask = np.ones((row, col,d), np.uint8) mask[row//size:row//size*(size-1),col//size:col//size*(size-1)] = 0 else: mask=np.ones((row,col,d),np.uint8) mask[row // 2-30:row // 2+30, col // 2-30:col // 2+30] = 0 return mask def _normal(img): a,b=img.max(),img.min() return np.clip((img-b)/(a-b),0,1) def getRGBIDFT(img,type): fft=getRGBDFT(img) ifft=[_normal(_idft(i,type)) for i in fft] return cv.merge(ifft) def swap(ori,aug): oriFFT=getRGBDFT(ori) augFFT=getRGBDFT(aug) HPF=_Mask(ori,"HPF",size=3) LPF=_Mask(ori,"LPF",size=4) res=[augFFT[i]*LPF+oriFFT[i]*HPF for i in range(len(oriFFT))] res=[_normal(_idft(i)) for i in res] return cv.merge(res) res=swap(ori,aug) plt.subplot(131),plt.imshow(ori) plt.title("Ori"),plt.xticks([]),plt.yticks([]) plt.subplot(132),plt.imshow(aug) plt.title("Aug"),plt.xticks([]),plt.yticks([]) plt.subplot(133),plt.imshow(res) plt.title("res"),plt.xticks([]),plt.yticks([]) plt.show()