import cv2
import numpy as np
from .utils import histMatch, histMatch2, Image
import copy
from osgeo import gdal, gdalconst
from osgeo import osr
def PCA(src: Image, pan: Image) -> Image:
    """Fuse the multi-band `src` image with the single-band `pan` image using PCA.

    The pixel work is delegated to `_PCA` on the raw arrays. The result is
    written into an in-memory VRT dataset that has the size of `pan`, one
    band per band of `src`, and the geotransform/projection of `pan`.

    :param src: The source image; it should be multi-band.
        The image is a `pyfusion.utils.Image`.
    :param pan: The panchromatic image; it must be single band.
        The image is a `pyfusion.utils.Image`.
    :return: The fused image on the `pan` grid.
    :rtype: `pyfusion.utils.Image`
    :raises AssertionError: if `pan` has more than one band.
    """
    assert pan.image.RasterCount == 1, "pan image should be 1 band"
    driver = gdal.GetDriverByName("VRT")
    # No eType is passed, so GDAL's default GDT_Byte is used; this matches
    # the uint8 array returned by `_PCA`.
    dstDs = driver.Create(
        "",
        pan.image.RasterXSize,
        pan.image.RasterYSize,
        src.image.RasterCount,
        options=["INTERLEAVE=PIXEL"],
    )
    # The fused raster lives on the pan grid, so carry over pan's
    # georeferencing; without this the output has no spatial reference.
    dstDs.SetGeoTransform(pan.image.GetGeoTransform())
    projection = pan.image.GetProjection()
    if projection:
        dstDs.SetProjection(projection)
    dstImage = Image(dstDs)
    dstImage.data = _PCA(src.data, pan.data)
    return dstImage
def _PCA(src, pan):
    """Fuse a multi-band array with a single-band pan array using PCA.

    Steps:

    1. Resize `src` to the pan size.
    2. Project every pixel's band vector onto its principal components.
    3. Replace the first component (PC1) with `pan` remapped by `histMatch2`
       against PC1 (a histogram match, per the helper's name).
    4. Back-project and clip to the 0..255 range.

    :param src: Multi-band array of shape (H_src, W_src, k); an image read by
        `cv2.imread` is the best practice.
    :param pan: Single-band 2-D array of shape (H, W); an image read by
        `cv2.imread` with `flags=cv2.IMREAD_GRAYSCALE` is the best practice.
    :return: The fused image of shape (H, W, k).
    :rtype: numpy.ndarray (dtype uint8)
    """
    height, width = pan.shape
    n = pan.size
    k = src.shape[2]
    # Upsample src to the pan size; cv2.resize takes dsize as (width, height).
    src = cv2.resize(src, (width, height))
    # One row per pixel, one column per band. A plain reshape also covers
    # k == 1, where cv2.resize drops the trailing channel axis.
    data = src.reshape(n, k).astype(np.float64)
    # Principal components of the band vectors.
    mean, eigenvectors = cv2.PCACompute(data, None)
    pca_transform = cv2.PCAProject(data, mean, eigenvectors)
    # Match pan to PC1, then substitute it for PC1.
    pc1 = pca_transform[:, 0:1]
    pan_new = histMatch2(pan, pc1)
    pca_transform[:, 0:1] = pan_new.reshape((n, 1))
    # Back to band space.
    result = cv2.PCABackProject(pca_transform, mean, eigenvectors)
    # Clip to the 8-bit range and restore the (H, W, k) layout.
    # astype truncates the fractional part rather than rounding.
    fused = np.clip(result, 0, 255).reshape(height, width, k)
    return fused.astype(np.uint8)