"""pyfusion.PCA: fuse a multi-band image with a single-band pan image via PCA."""

import cv2
import numpy as np
from .utils import histMatch, histMatch2, Image
import copy

from osgeo import gdal, gdalconst
from osgeo import osr

def PCA(src: Image, pan: Image):
    """Fuse `src` with `pan` based on PCA.

    The fused image has the pan image's raster size and the source image's
    band count. It is placed on the pan image's grid and carries over the pan
    image's geotransform and (when set) projection.

    :param src: The source image; it should be multi-band. The image is a
        `pyfusion.utils.Image`.
    :param pan: The pan image; it should be single-band. The image is a
        `pyfusion.utils.Image`.
    :return: The fused image.
    :rtype: `pyfusion.utils.Image`
    :raises AssertionError: If `pan` does not have exactly one band.
    """
    assert pan.image.RasterCount == 1, "pan image should be 1 band"
    driver = gdal.GetDriverByName("VRT")
    dstDs = driver.Create(
        "",
        pan.image.RasterXSize,
        pan.image.RasterYSize,
        src.image.RasterCount,
        options=["INTERLEAVE=PIXEL"],
    )
    # The output shares the pan image's pixel grid, so reuse its
    # georeferencing; otherwise the fused result has no spatial reference.
    dstDs.SetGeoTransform(pan.image.GetGeoTransform())
    projection = pan.image.GetProjection()
    if projection:
        dstDs.SetProjection(projection)
    dstImage = Image(dstDs)
    dstImage.data = _PCA(src.data, pan.data)
    return dstImage
def _PCA(src, pan):
    """Fuse `src` with `pan` based on PCA.

    `src` is resized to the pan image's size, and its pixels are projected
    onto their principal components. The first principal component (PC1) is
    then replaced by `pan` histogram-matched to PC1 (via `histMatch2`), and
    the components are projected back into band space.

    :param src: The source image, multi-band, indexed as
        ``(rows, cols, bands)``. An image read by `cv2.imread` works best.
    :param pan: The pan image, single-band, indexed as ``(rows, cols)``. An
        image read by `cv2.imread` with `flag` set to `cv2.IMREAD_GRAYSCALE`
        works best.
    :return: The fused image, shape ``(pan_rows, pan_cols, bands)``, with
        values clipped to ``[0, 255]`` and cast to ``uint8``.
    :rtype: numpy.ndarray
    """
    rows, cols = pan.shape[:2]
    n_pixels = rows * cols
    n_bands = src.shape[2]

    # Upsample src to the pan size; cv2.resize takes dsize as (width, height).
    src = cv2.resize(src, (cols, rows))
    # One row per pixel, one column per band. Reshaping (rather than indexing
    # src[:, :, i]) also handles a single-band src, for which cv2.resize
    # drops the trailing channel axis.
    data = src.reshape(n_pixels, n_bands).astype(np.float64)

    # Forward PCA.
    mean, eigenvectors = cv2.PCACompute(data, None)
    components = cv2.PCAProject(data, mean, eigenvectors)

    # Replace PC1 with the pan image histogram-matched to PC1.
    pc1 = components[:, 0:1]
    pan_matched = histMatch2(pan, pc1)
    components[:, 0:1] = pan_matched.reshape(n_pixels, 1)

    # Inverse PCA back into band space.
    fused = cv2.PCABackProject(components, mean, eigenvectors)

    # Clip to the 8-bit range and restore the (rows, cols, bands) layout.
    # NOTE: astype truncates toward zero rather than rounding.
    fused = np.clip(fused, 0, 255)
    return fused.reshape(rows, cols, n_bands).astype(np.uint8)