Source code for pyfusion.utils.utils

import math
import numpy as np
import numba
import cv2

from osgeo import gdal, gdalconst
from osgeo import osr
# from bidict import bidict


# dataTypeMap = bidict({
#     gdal.GDT_Unknown : object,
#     gdal.GDT_Byte : np.int8,
#     gdal.GDT_UInt16 : np.uint16,
#     gdal.GDT_Int16 : np.int16,
#     gdal.GDT_UInt32 : np.uint32,
#     gdal.GDT_Int32 : np.int32,
#     gdal.GDT_Float32 : np.float32,
#     gdal.GDT_Float64 : np.float64,
# })

@numba.jit
def Entropy(src):
    """Calculate the information entropy (in bits) of a source image.

    Every pixel value is used as an index into a 256-bin histogram, the bin
    counts are divided by ``src.size`` to obtain relative frequencies, and
    ``-sum(p * log2(p))`` is accumulated over the non-empty bins.

    :param src: source image, it should be read by `cv2.imread()` and be a
        grayscale image (it is indexed as ``src[row, col]``)
    :return: information entropy
    :rtype: float
    """
    histogram = [0.0] * 256
    rows = src.shape[0]
    cols = src.shape[1]
    for row in range(rows):
        for col in range(cols):
            histogram[src[row, col]] += 1

    entropy = 0
    pixel_count = src.size
    for level in range(256):
        probability = histogram[level] / pixel_count
        # Empty bins contribute nothing (and log2(0) is undefined).
        if probability != 0:
            entropy -= probability * (np.log2(probability))
    return entropy
@numba.jit
def avg_grad(src):
    """Calculate the average gradient of a source image.

    For every pixel except those in the last row and last column, the forward
    differences towards the right neighbour and the neighbour below are
    combined as ``sqrt((dx**2 + dy**2) / 2)`` and summed.

    NOTE: the sum is divided by the full pixel count ``rows * cols`` even
    though only ``(rows - 1) * (cols - 1)`` terms are accumulated.

    :param src: source image, it should be read by `cv2.imread()` and be a
        grayscale image
    :return: average gradient
    :rtype: float
    """
    pixels = src.astype(np.float64)
    height = pixels.shape[0]
    width = pixels.shape[1]

    total = 0.0
    for row in range(height - 1):
        for col in range(width - 1):
            here = pixels[row, col]
            dx = pixels[row, col + 1] - here
            dy = pixels[row + 1, col] - here
            total += math.sqrt((dx**2 + dy**2) / 2)
    return total / (height * width)
@numba.jit
def RMSE(src, pan):
    """Calculate the root-mean-square error between two images.

    Both inputs are converted to ``float64``; the squared per-pixel
    differences are summed over the extent of ``src`` and the square root of
    their mean is returned.

    :param src: source image, it should be read by `cv2.imread()` and be a
        grayscale image
    :param pan: image compared against ``src``; it is indexed over the same
        ``[row, col]`` range as ``src``
    :return: RMSE
    :rtype: float
    """
    first = src.astype(np.float64)
    second = pan.astype(np.float64)
    height = first.shape[0]
    width = first.shape[1]

    squared_error = 0.0
    for row in range(height):
        for col in range(width):
            diff = first[row, col] - second[row, col]
            squared_error += diff * diff
    return math.sqrt(squared_error / (height * width))
def R_value(src, pan):
    """Calculate the correlation coefficient (R value) between two images.

    The means are taken from the first element of ``cv2.mean`` for each
    image, the images are converted to ``float64`` and the work is delegated
    to ``_Rsqure``. The returned value is the coefficient itself; it is not
    squared.

    :param src: source image, it should be read by `cv2.imread()` and be a
        grayscale image
    :param pan: image compared against ``src``
    :return: R value
    :rtype: float
    """
    src_mean = cv2.mean(src)[0]
    pan_mean = cv2.mean(pan)[0]
    return _Rsqure(
        src.astype(np.float64),
        pan.astype(np.float64),
        src_mean,
        pan_mean,
    )
@numba.jit
def _Rsqure(src, pan, src_mean, pan_mean):
    """Return the correlation coefficient of two images given their means.

    Computes ``sum((s - src_mean) * (p - pan_mean))`` divided by the square
    root of ``sum((s - src_mean)**2) * sum((p - pan_mean)**2)``, iterating
    over the ``[row, col]`` extent of ``src``.

    :param src: first image, indexed as ``src[row, col]``
    :param pan: second image, indexed over the same range as ``src``
    :param src_mean: mean value subtracted from every ``src`` pixel
    :param pan_mean: mean value subtracted from every ``pan`` pixel
    :return: correlation coefficient
    """
    covariance_sum = 0
    src_square_sum = 0
    pan_square_sum = 0
    height = src.shape[0]
    width = src.shape[1]
    for row in range(height):
        for col in range(width):
            src_dev = src[row, col] - src_mean
            pan_dev = pan[row, col] - pan_mean
            covariance_sum += src_dev * pan_dev
            src_square_sum += src_dev ** 2
            pan_square_sum += pan_dev ** 2
    return covariance_sum / math.sqrt(src_square_sum * pan_square_sum)
@numba.jit
def piancha_relativepiancha(src, pan):
    """Calculate the mean absolute deviation and mean relative deviation.

    For every pixel the absolute difference ``|src - pan|`` is accumulated.
    The relative deviation additionally divides that difference by the
    ``pan`` pixel; pixels where ``pan`` is zero are left out of the relative
    sum but still count towards the divisor, which is the full pixel count
    of ``src`` for both results.

    :param src: source image, indexed as ``src[row, col]``
    :param pan: reference image, indexed over the same range as ``src``
    :return: tuple ``(deviation, relative_deviation)``
    """
    fused = src.astype(np.float64)
    reference = pan.astype(np.float64)
    height = fused.shape[0]
    width = fused.shape[1]

    deviation_sum = 0
    relative_sum = 0
    for row in range(height):
        for col in range(width):
            ref_pixel = reference[row, col]
            gap = abs(fused[row, col] - ref_pixel)
            deviation_sum += gap
            # A zero reference pixel would divide by zero; skip its relative term.
            if ref_pixel != 0:
                relative_sum += gap / ref_pixel

    pixel_count = height * width
    return (deviation_sum / pixel_count), (relative_sum / pixel_count)
def scala_image(img, rg):
    """Linearly rescale every channel of an image into the range ``rg``.

    Each channel returned by ``cv2.split`` is converted to ``float32``,
    min-max normalised independently, mapped onto ``[rg[0], rg[1]]`` and cast
    to ``uint8``; the channels are then recombined with ``cv2.merge``.

    :param img: multi-channel image accepted by ``cv2.split``
    :param rg: two-element sequence ``(low, high)`` giving the target range
    :return: rescaled image with ``uint8`` channels
    :raises ValueError: if ``rg[0]`` is greater than ``rg[1]``
    """
    low, high = rg[0], rg[1]
    if low > high:
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError("range error")

    scaled_channels = []
    for channel in cv2.split(img):
        channel = channel.astype(np.float32)
        channel_min = np.min(channel)
        span = np.max(channel) - channel_min
        if span == 0:
            # A constant channel has no dynamic range; dividing by the zero
            # span would yield NaN, so pin it to the lower bound instead.
            scaled = np.full(channel.shape, low, dtype=np.float32)
        else:
            scaled = ((channel - channel_min) / span) * (high - low) + low
        scaled_channels.append(scaled.astype(np.uint8))
    return cv2.merge(scaled_channels)
class Image:
    """Pair a GDAL dataset with a lazily loaded NumPy copy of its pixels.

    ``data`` is laid out as ``(RasterYSize, RasterXSize, RasterCount)`` for
    multi-band datasets and ``(RasterYSize, RasterXSize)`` for single-band
    ones, which is the layout produced by the ``data`` getter.
    """

    def __init__(self, image=None):
        self._image = image  # GDAL Dataset
        self._data = None  # pixel data converted to a numpy.ndarray

    def _check_shape(self, data):
        """Raise ``ValueError`` unless ``data`` matches the dataset layout."""
        expected = (
            self._image.RasterYSize,
            self._image.RasterXSize,
            self._image.RasterCount,
        )
        shape = tuple(data.shape)
        if shape == expected:
            return
        # The getter returns a 2-D array for single-band datasets.
        if self._image.RasterCount == 1 and shape == expected[:2]:
            return
        raise ValueError(
            "data not match, need {} but get {}".format(expected, data.shape)
        )

    def toGeoTiff(self, filepath: str):
        """Write the current pixel data to a GeoTIFF file.

        The output copies the raster size, band count, geotransform and (when
        present) projection of the wrapped dataset. The file is created with
        the driver's default data type, so the bytes of ``data`` are written
        as-is.

        :param filepath: path of the GeoTIFF file to create
        :raises ValueError: if no data is loaded or its shape does not match
            the wrapped dataset
        :raises RuntimeError: if the output file cannot be created
        """
        if self._data is None:
            raise ValueError("No data")
        self._check_shape(self._data)

        x_size = self._image.RasterXSize
        y_size = self._image.RasterYSize
        band_count = self._image.RasterCount

        driver = gdal.GetDriverByName("GTiff")
        tods = driver.Create(
            filepath, x_size, y_size, band_count, options=["INTERLEAVE=PIXEL"]
        )
        if tods is None:
            raise RuntimeError("cannot create {}".format(filepath))

        tods.SetGeoTransform(self._image.GetGeoTransform())
        projection = self._image.GetProjection()
        if projection:
            # Without the projection the geotransform alone does not
            # georeference the output.
            tods.SetProjection(projection)

        for i in range(band_count):
            band = self._data if self._data.ndim == 2 else self._data[:, :, i]
            # tobytes() serialises in C (row-major) order; it replaces the
            # deprecated ndarray.tostring().
            tods.WriteRaster(
                0, 0, x_size, y_size,
                band.tobytes(),
                x_size, y_size,
                band_list=[i + 1],
            )
        tods.FlushCache()

    @property
    def image(self):
        """The wrapped GDAL dataset (or ``None``)."""
        return self._image

    @image.deleter
    def image(self):
        self._image = None

    @property
    def data(self):
        """Pixel data as a ``uint8`` array, read from the dataset on first use.

        Single-band datasets yield a 2-D array; otherwise the bands are
        stacked along the last axis of a ``(ySize, xSize, bands)`` array.
        """
        if self._data is not None:
            return self._data

        bn = self._image.RasterCount
        xSize = self._image.RasterXSize
        ySize = self._image.RasterYSize

        if bn == 1:
            self._data = self._image.GetRasterBand(1).ReadAsArray(
                buf_type=gdalconst.GDT_Byte
            )
            return self._data

        self._data = np.empty((ySize, xSize, bn), dtype=np.uint8)
        for i in range(bn):
            band = self._image.GetRasterBand(i + 1)  # GDAL bands are 1-based
            self._data[:, :, i] = band.ReadAsArray(buf_type=gdalconst.GDT_Byte)
        return self._data

    @data.setter
    def data(self, data):
        self._check_shape(data)
        self._data = data

    @data.deleter
    def data(self):
        self._data = None
def ReadImageFile(filepath):
    """Open a raster file read-only with GDAL and wrap it in an ``Image``.

    :param filepath: path passed to ``gdal.OpenShared``
    :return: ``Image`` wrapping whatever ``gdal.OpenShared`` returned
    """
    return Image(image=gdal.OpenShared(filepath, gdal.GA_ReadOnly))