import cv2
import numpy as np
from .utils import histMatch, histMatch2, Image, ReadImageFile, GetOutputDriverFor
import copy
import pyarma as pa
from osgeo import gdal, gdalconst
from osgeo import osr
def PCAByFilePath(srcpath: str, panpath: str, resultpath: str, driver: str):
    """Fuse a multi-band image with a panchromatic image using PCA, file to file.

    Both inputs are opened with ``ReadImageFile`` and handed to :func:`PCA`,
    which writes the fused image to ``resultpath``.

    :param srcpath: path of the multi-band `source` image
    :param panpath: path of the single-band panchromatic image
    :param resultpath: path of the result image
    :param driver: name of the GDAL driver used to write ``resultpath``
        (the driver must support ``Create()``); when ``None``, ``PCA``
        picks the driver from ``resultpath`` via ``GetOutputDriverFor``
    :return: None
    :rtype: None
    """
    result = PCA(ReadImageFile(srcpath), ReadImageFile(panpath), None, resultpath, driver)
    # Flush pending writes explicitly so the file on disk is complete
    # before the last reference to the output dataset is dropped.
    result.image.FlushCache()
def PCA(src: Image, pan: Image, block_size=None, result_path=None, driver=None):
    """Fuse a multi-band image with a panchromatic image using PCA.

    The output raster has the pixel size, geotransform and projection of
    ``pan`` and one band per band of ``src``. It is processed in tiles: each
    tile of ``pan`` and the matching (scaled) window of ``src`` are fused
    independently by ``_PCA``, so the principal components are computed per
    tile, not over the whole image.

    :param src: the multi-band source image (`pyfusion.utils.Image`)
    :param pan: the single-band panchromatic image (`pyfusion.utils.Image`)
    :param block_size: nominal tile edge length in pan pixels (default 4096).
        Each axis is split into ``size // block_size`` (at least 1) roughly
        equal tiles, so a tile can be up to about twice ``block_size``.
    :param result_path: result image path; when ``None`` the result is kept
        in memory (GDAL ``MEM`` driver)
    :param driver: name of the GDAL driver used to write ``result_path``;
        when ``None`` the driver comes from ``GetOutputDriverFor(result_path)``.
        Ignored when ``result_path`` is ``None``. The driver must support
        ``Create()``.
    :return: the fused image
    :rtype: `pyfusion.utils.Image`
    :raises AssertionError: if ``pan`` does not have exactly one band
    :raises ValueError: if ``block_size`` is not positive
    :raises RuntimeError: if no GDAL driver is found or the output dataset
        cannot be created
    """
    assert pan.image.RasterCount == 1, "pan image should be 1 band"
    if block_size is None:
        block_size = 4096
    if block_size <= 0:
        raise ValueError("block_size should be positive, got {}".format(block_size))

    xsize = pan.image.RasterXSize
    ysize = pan.image.RasterYSize
    # One output band per source band (this was previously hard-coded to 3).
    band_count = src.image.RasterCount

    # Choose the output driver: in-memory when no path is given, otherwise
    # the explicitly named driver, otherwise one derived from the path.
    if result_path is None:
        gdal_driver = gdal.GetDriverByName("MEM")
        out_path = ""
    elif driver is not None:
        gdal_driver = gdal.GetDriverByName(driver)
        out_path = result_path
    else:
        gdal_driver = GetOutputDriverFor(result_path)
        out_path = result_path
    if gdal_driver is None:
        raise RuntimeError(
            "no GDAL driver available (result_path={!r}, driver={!r})".format(result_path, driver))

    dstDs = gdal_driver.Create(out_path, xsize, ysize, band_count, options=["INTERLEAVE=PIXEL"])
    if dstDs is None:
        raise RuntimeError("GDAL failed to create output dataset {!r}".format(out_path))
    dstDs.SetGeoTransform(pan.image.GetGeoTransform())
    dstDs.SetProjection(pan.image.GetProjection())
    dstDs.SetSpatialRef(pan.image.GetSpatialRef())
    dstImage = Image(dstDs)

    x_blocks = max(xsize // block_size, 1)
    y_blocks = max(ysize // block_size, 1)
    # The +1 guarantees blocks * step > size, so the tiles always reach the
    # far edge of the image.
    x_step = xsize // x_blocks + 1
    y_step = ysize // y_blocks + 1
    # Scale factors from pan pixel coordinates to src pixel coordinates.
    x_rate = src.image.RasterXSize / xsize
    y_rate = src.image.RasterYSize / ysize

    for i in range(x_blocks):
        x_start = i * x_step
        x_end = min(x_start + x_step, xsize)
        if x_start >= x_end:
            # With more tiles than pixels per tile, the +1 in the step can
            # push trailing tiles past the edge; they are empty and their
            # area is already covered, so skip them instead of crashing.
            continue
        for j in range(y_blocks):
            y_start = j * y_step
            y_end = min(y_start + y_step, ysize)
            if y_start >= y_end:
                continue
            src_tile = src[int(y_start * y_rate):int(y_end * y_rate),
                           int(x_start * x_rate):int(x_end * x_rate)]
            pan_tile = pan[y_start:y_end, x_start:x_end]
            dstImage[y_start:y_end, x_start:x_end] = _PCA(src_tile, pan_tile)
    return dstImage
def _PCA(src, pan):
    """Fuse one source tile with a panchromatic tile using PCA.

    Steps:

    1. upsample ``src`` to the size of ``pan`` with ``cv2.resize``;
    2. run PCA over the pixels (``pyarma.princomp``), one column per band;
    3. replace the first principal component with the mean-centred ``pan``
       projected onto the first loading vector and passed through
       ``histMatch2`` against PC1 (see ``histMatch2`` for its exact
       matching semantics);
    4. transform back, re-add the band means, zero every sample whose
       (upsampled) source value was 0, clip to [0, 255] and cast to uint8.

    :param src: source tile, array of shape (rows, cols, bands)
    :param pan: panchromatic tile, 2-D array of shape (rows, cols)
    :return: fused tile of shape (pan rows, pan cols, bands); an all-zero
        array if the upsampled ``src`` is entirely zero
    :rtype: numpy.ndarray (uint8)
    :raises RuntimeError: if ``pyarma.princomp`` reports failure
    """
    n = pan.size
    k = src.shape[2]
    # cv2.resize takes the target size as (width, height).
    src = cv2.resize(src, tuple(reversed(pan.shape)))
    if np.max(src) == 0:
        return np.zeros(src.shape, dtype=np.uint8)

    # One row per pixel (row-major order), one column per band.
    data = src.reshape(n, k).astype(np.float64)
    mean = np.mean(data, axis=0)

    data_mat = pa.mat(data)
    coeff, score = pa.mat(), pa.mat()
    if not pa.princomp(coeff, score, data_mat):
        raise RuntimeError("princomp error, data = {}".format(data_mat))
    pca_transform = np.array(score)
    coeff = np.array(coeff)
    loading1 = coeff[:, 0]
    pc1 = pca_transform[:, 0:1]

    # Centre the pan band (astype makes a copy, so the caller's array is untouched).
    pan = pan.astype(np.float64)
    pan_mean = np.mean(pan)
    pan -= pan_mean

    # Project the centred pan, replicated once per band (k, not a fixed 3),
    # onto the first loading vector.
    pan_projected = np.dot(np.tile(pan.reshape((n, 1)), (1, k)), loading1).reshape(pan.shape)
    offsets = (np.sum(-pan_mean * loading1), np.sum(-mean * loading1))
    pan_new = histMatch2(pan_projected, pc1, offsets)

    # Substitute PC1 and invert the transform (for the orthonormal loadings
    # princomp returns, inv(coeff) equals coeff.T).
    pca_transform[:, 0:1] = pan_new.reshape((n, 1))
    result = np.dot(pca_transform, np.linalg.inv(coeff))
    result += mean
    # Keep zero (presumably nodata) source samples at zero in the output.
    result[data == 0] = 0
    np.clip(result, 0, 255, out=result)
    return result.reshape(pan.shape + (k,)).astype(np.uint8)