# Source code for pyfusion.utils.utils

import math
import os

import cv2
import numba
import numpy as np
from bidict import bidict
from osgeo import gdal, gdalconst
from osgeo import osr


# Two-way mapping between GDAL raster data-type codes and NumPy scalar types:
# ``dataTypeMap[gdal_code]`` gives the NumPy type and
# ``dataTypeMap.inverse[numpy_type]`` gives the GDAL code.
dataTypeMap = bidict(
    [
        (gdal.GDT_Unknown, object),
        (gdal.GDT_Byte, np.uint8),
        (gdal.GDT_UInt16, np.uint16),
        (gdal.GDT_Int16, np.int16),
        (gdal.GDT_UInt32, np.uint32),
        (gdal.GDT_Int32, np.int32),
        (gdal.GDT_Float32, np.float32),
        (gdal.GDT_Float64, np.float64),
    ]
)

@numba.jit
def Entropy(src):
    """Return the Shannon entropy, in bits, of a grayscale image.

    A 256-bin histogram of the pixel values is normalised to probabilities
    ``p``. The entropy ``-sum(p * log2(p))`` is then taken over the non-empty
    bins.

    :param src: 2-D integer image whose values lie in [0, 255]
        (e.g. a grayscale image read with ``cv2.imread()``)
    :return: information entropy in bits
    :rtype: float
    :raises ValueError: if any pixel value falls outside [0, 255]
    """
    # Pixel values index the histogram directly, and numba does not
    # bounds-check array accesses by default, so out-of-range values would
    # silently write past the histogram. Reject them up front.
    if src.size > 0 and (np.min(src) < 0 or np.max(src) > 255):
        raise ValueError("pixel values must lie in [0, 255]")
    counts = np.zeros(256)
    for i in range(src.shape[0]):
        for j in range(src.shape[1]):
            counts[src[i, j]] += 1
    total = src.size
    result = 0.0
    for k in range(256):
        p = counts[k] / total
        if p != 0:
            result -= p * np.log2(p)
    return result
@numba.jit
def avg_grad(src):
    """Return the average gradient of a single-channel image.

    Forward differences ``gx = I[y, x+1] - I[y, x]`` and
    ``gy = I[y+1, x] - I[y, x]`` are computed for every pixel except those in
    the last row and last column. Each pair is combined as
    ``sqrt((gx**2 + gy**2) / 2)``, and the sum is divided by the full pixel
    count ``rows * cols``.

    :param src: 2-D image (e.g. a grayscale image read with ``cv2.imread()``);
        it is converted to float64 before differencing
    :return: average gradient
    :rtype: float
    """
    img = src.astype(np.float64)
    height = img.shape[0]
    width = img.shape[1]
    total = 0.0
    for y in range(height - 1):
        for x in range(width - 1):
            here = img[y, x]
            gx = img[y, x + 1] - here
            gy = img[y + 1, x] - here
            total += math.sqrt((gx ** 2 + gy ** 2) / 2)
    return total / (height * width)
@numba.jit
def RMSE(src, pan):
    """Return the root-mean-square error between two single-channel images.

    :param src: 2-D image (e.g. a grayscale image read with ``cv2.imread()``)
    :param pan: reference 2-D image with the same height and width as ``src``
    :return: ``sqrt(mean((src - pan) ** 2))``, computed in float64
    :rtype: float
    :raises ValueError: if the two images differ in height or width
    """
    # numba does not bounds-check indexing by default, so a smaller ``pan``
    # would silently read out-of-range memory. Reject mismatched sizes.
    if src.shape[0] != pan.shape[0] or src.shape[1] != pan.shape[1]:
        raise ValueError("src and pan must have the same height and width")
    src = src.astype(np.float64)
    pan = pan.astype(np.float64)
    tmp = 0.0
    for i in range(src.shape[0]):
        for j in range(src.shape[1]):
            diff = src[i, j] - pan[i, j]
            tmp += diff * diff
    return math.sqrt(tmp / (src.shape[0] * src.shape[1]))
def R_value(src, pan):
    """Return the correlation coefficient R between two single-channel images.

    The value is the Pearson correlation coefficient (in [-1, 1]), not its
    square.

    :param src: 2-D image (e.g. a grayscale image read with ``cv2.imread()``)
    :param pan: reference 2-D image with the same height and width as ``src``
    :return: correlation coefficient R
    :rtype: float
    :raises ValueError: if the two images differ in height or width
    """
    # The jitted helper indexes both arrays with the same (i, j) without
    # bounds checks, so sizes must agree before calling it.
    if src.shape[:2] != pan.shape[:2]:
        raise ValueError("src and pan must have the same height and width")
    # cv2.mean returns a per-channel 4-tuple; [0] is the (single) channel mean.
    mean_src = cv2.mean(src)[0]
    mean_pan = cv2.mean(pan)[0]
    return _Rsqure(src.astype(np.float64), pan.astype(np.float64), mean_src, mean_pan)
@numba.jit
def _Rsqure(src, pan, src_mean, pan_mean):
    """Pearson correlation of two equally sized 2-D arrays, given their means.

    Returns ``cov / sqrt(var_src * var_pan)``, where all three sums run over
    every pixel.
    """
    cov = 0
    var_src = 0
    var_pan = 0
    rows = src.shape[0]
    cols = src.shape[1]
    for r in range(rows):
        for c in range(cols):
            ds = src[r, c] - src_mean
            dp = pan[r, c] - pan_mean
            cov += ds * dp
            var_src += ds ** 2
            var_pan += dp ** 2
    return cov / math.sqrt(var_src * var_pan)
@numba.jit
def piancha_relativepiancha(src, pan):
    """Return the mean absolute and mean relative deviation of ``src`` from ``pan``.

    ("piancha" is pinyin for "deviation".) For every pixel, ``|src - pan|``
    is accumulated. The relative term additionally divides by ``pan`` and is
    skipped for pixels where ``pan == 0``. Both sums are divided by the total
    pixel count, skipped pixels included.

    :param src: 2-D image array
    :param pan: reference 2-D image with the same height and width as ``src``
    :return: ``(mean_abs_deviation, mean_relative_deviation)``
    :rtype: tuple(float, float)
    :raises ValueError: if the two images differ in height or width
    """
    # numba does not bounds-check indexing by default; reject mismatched sizes.
    if src.shape[0] != pan.shape[0] or src.shape[1] != pan.shape[1]:
        raise ValueError("src and pan must have the same height and width")
    src = src.astype(np.float64)
    pan = pan.astype(np.float64)
    piancha = 0.0
    relativepiancha = 0.0
    for i in range(src.shape[0]):
        for j in range(src.shape[1]):
            pixel = abs(src[i, j] - pan[i, j])
            piancha += pixel
            if pan[i, j] != 0:
                relativepiancha += pixel / pan[i, j]
    n_pixels = src.shape[0] * src.shape[1]
    return piancha / n_pixels, relativepiancha / n_pixels
def _scale_channel(channel, low, high):
    """Min-max normalise one 2-D channel onto [low, high] and truncate to uint8."""
    chan = channel.astype(np.float32)
    cmin, cmax = np.min(chan), np.max(chan)
    if cmax == cmin:
        # A constant channel has no spread to stretch. Map it to the low end
        # instead of dividing by zero (which would produce NaNs).
        return np.full(chan.shape, low, dtype=np.float32).astype(np.uint8)
    return (((chan - cmin) / (cmax - cmin)) * (high - low) + low).astype(np.uint8)


def scala_image(img, rg):
    """Linearly stretch every channel of an image into the range ``rg``.

    Each channel is min-max normalised independently, mapped onto
    ``[rg[0], rg[1]]`` in float32 and truncated to ``uint8``. A constant
    channel is mapped to ``rg[0]``.

    :param img: image array of shape (H, W, C), e.g. a BGR image from
        ``cv2.imread()``, or a 2-D (H, W) image
    :param rg: ``(low, high)`` target range; values outside [0, 255] do not
        fit in the uint8 result
    :return: uint8 array with the same shape as ``img``
    :raises ValueError: if ``rg[0] > rg[1]``
    """
    low, high = rg[0], rg[1]
    if low > high:
        raise ValueError("range error")
    arr = np.asarray(img)
    # Treat a 2-D image as a single-channel one so the loop below is uniform.
    channels = arr[..., np.newaxis] if arr.ndim == 2 else arr
    scaled = np.empty(channels.shape, dtype=np.uint8)
    for k in range(channels.shape[-1]):
        scaled[..., k] = _scale_channel(channels[..., k], low, high)
    return scaled[..., 0] if arr.ndim == 2 else scaled
class Image:
    """Wrapper around an ``osgeo.gdal.Dataset`` with slice-based pixel access.

    Pixels can be read and written directly with ``img[row, col]``. Each of
    ``row`` and ``col`` is an ``int`` or a slice whose step is ``None`` or
    ``1``:

    * reading returns ``Dataset.ReadAsArray(..., interleave="pixel")`` for the
      selected window;
    * writing expects a NumPy array of shape ``(rows, cols, bands)`` whose
      dtype matches the dataset's band type (looked up in ``dataTypeMap``).

    The wrapped dataset itself is available through the :attr:`image`
    property.
    """

    def __init__(self, image=None):
        self._image = image  # the wrapped GDAL Dataset, or None

    @property
    def image(self):
        """The wrapped ``osgeo.gdal.Dataset`` (``None`` once deleted)."""
        return self._image

    @image.deleter
    def image(self):
        self._image = None

    @staticmethod
    def _check_steps(*slices):
        """Raise ValueError unless every slice in ``slices`` has step None or 1."""
        if any(s.step not in (None, 1) for s in slices):
            raise ValueError(
                "The slice step can not bigger than 1, we get "
                + " and ".join(str(s.step) for s in slices)
            )

    @staticmethod
    def _span(s, size):
        """Translate slice ``s`` on an axis of length ``size`` into ``(offset, length)``."""
        offset = 0 if s.start is None else s.start
        length = (size - offset) if s.stop is None else (s.stop - offset)
        return offset, length

    def __getitem__(self, item):
        """Read the pixel window selected by ``item == (row, col)``.

        :param item: 2-tuple of ints and/or step-1 slices (negative indices
            are not translated)
        :return: the array returned by ``Dataset.ReadAsArray`` for the window
        :raises ValueError: if no dataset is wrapped, a slice step is not 1,
            or an index is neither int nor slice
        :raises TypeError: if ``item`` is not a tuple
        :raises NotImplementedError: if ``item`` does not have two entries
        """
        if self._image is None:
            raise ValueError("Image is None")
        if not isinstance(item, tuple):
            raise TypeError("It must be a tuple")
        if len(item) != 2:
            raise NotImplementedError("Not implement!")
        row, col = item
        x_total = self._image.RasterXSize
        y_total = self._image.RasterYSize
        if isinstance(row, int) and isinstance(col, int):
            xoff, yoff, xsize, ysize = col, row, 1, 1
        elif isinstance(row, slice) and isinstance(col, slice):
            # A bad step on *either* axis is rejected.
            self._check_steps(row, col)
            xoff, xsize = self._span(col, x_total)
            yoff, ysize = self._span(row, y_total)
        elif isinstance(row, int) and isinstance(col, slice):
            self._check_steps(col)
            xoff, xsize = self._span(col, x_total)
            yoff, ysize = row, 1
        elif isinstance(row, slice) and isinstance(col, int):
            self._check_steps(row)
            xoff, xsize = col, 1
            yoff, ysize = self._span(row, y_total)
        else:
            raise ValueError("Illegal slice, {}".format(item))
        return self._image.ReadAsArray(xoff, yoff, xsize, ysize, interleave="pixel")

    def __setitem__(self, key, value: np.ndarray):
        """Write ``value`` into the pixel window selected by ``key == (row, col)``.

        :param key: 2-tuple of ints and/or step-1 slices (negative indices
            are not translated)
        :param value: array of shape ``(rows, cols, bands)`` whose dtype
            matches band 1's GDAL data type via ``dataTypeMap``
        :raises ValueError: if no dataset is wrapped, the window exceeds the
            raster, a slice step is not 1, an index is neither int nor slice,
            or ``value`` has the wrong shape
        :raises TypeError: if the dtype does not match or ``key`` is not a tuple
        :raises NotImplementedError: if ``key`` does not have two entries
        :raises Exception: with GDAL's last error message if ``WriteRaster`` fails
        """
        # The missing-dataset check must come first: the dtype check uses it.
        if self._image is None:
            raise ValueError("Image is None")
        data_type = self._image.GetRasterBand(1).DataType
        if dataTypeMap[data_type] != value.dtype.type:
            raise TypeError("type not match:", data_type, value.dtype.type)
        if not isinstance(key, tuple):
            raise TypeError("It must be a tuple")
        if len(key) != 2:
            raise NotImplementedError("Not implement!")
        row, col = key
        x_total = self._image.RasterXSize
        y_total = self._image.RasterYSize
        bn = self._image.RasterCount
        if isinstance(row, int) and isinstance(col, int):
            xoff, yoff, xsize, ysize = col, row, 1, 1
        elif isinstance(row, slice) and isinstance(col, slice):
            xoff, xsize = self._span(col, x_total)
            yoff, ysize = self._span(row, y_total)
            if (xoff > x_total or yoff > y_total
                    or xsize > x_total - xoff or ysize > y_total - yoff):
                raise ValueError("Shape error:", xoff, yoff, xoff + xsize, yoff + ysize)
            self._check_steps(row, col)
        elif isinstance(row, int) and isinstance(col, slice):
            xoff, xsize = self._span(col, x_total)
            if xoff > x_total or xsize > x_total - xoff:
                raise ValueError("Shape error:", xoff, xoff + xsize)
            self._check_steps(col)
            yoff, ysize = row, 1
        elif isinstance(row, slice) and isinstance(col, int):
            yoff, ysize = self._span(row, y_total)
            if yoff > y_total or ysize > y_total - yoff:
                raise ValueError("Shape error:", yoff, yoff + ysize)
            self._check_steps(row)
            xoff, xsize = col, 1
        else:
            raise ValueError("Illegal slice, {}".format(key))
        if value.shape != (ysize, xsize, bn):
            raise ValueError("wrong shape:", value.shape)
        # WriteRaster consumes the buffer band-sequentially by default, so
        # reorder (rows, cols, bands) -> (bands, rows, cols) for every window
        # shape; passing pixel-interleaved bytes would scramble multi-band data.
        buf = np.moveaxis(value, -1, 0).tobytes()
        if self._image.WriteRaster(xoff, yoff, xsize, ysize, buf, xsize, ysize,
                                   band_list=list(range(1, bn + 1))):
            raise Exception(gdal.GetLastErrorMsg())
def ReadImageFile(filepath: str):
    """Open a raster file read-only and wrap it in an ``Image``.

    The file is opened with ``gdal.OpenShared`` in ``gdal.GA_ReadOnly`` mode.

    :param filepath: path of the raster file
    :return: the opened image
    :rtype: `pyfusion.utils.Image`
    :raises OSError: if GDAL cannot open the file
    """
    dataset = gdal.OpenShared(filepath, gdal.GA_ReadOnly)
    if dataset is None:
        # Unless gdal.UseExceptions() is active, GDAL reports failure by
        # returning None. Fail here rather than returning an Image that only
        # errors on first access.
        raise OSError("Cannot open {}: {}".format(filepath, gdal.GetLastErrorMsg()))
    return Image(image=dataset)
def DoesDriverHandleExtension(drv, ext):
    """Return True if GDAL driver ``drv`` lists ``ext`` among its file extensions.

    ``DMD_EXTENSIONS`` is a space-separated list (e.g. ``"tif tiff"``), so
    ``ext`` is compared against whole entries. A substring search would
    wrongly match e.g. ``"if"`` against ``"tif"``.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and ext.lower() in exts.lower().split()


def GetExtension(filename):
    """Return the extension of ``filename`` without the leading dot ('' if none)."""
    # os.path.splitext yields either '' or a string starting with '.'.
    return os.path.splitext(filename)[1][1:]


def GetOutputDriversFor(filename):
    """Return the short names of GDAL raster drivers that can write ``filename``.

    A driver qualifies if it supports raster data and ``Create`` or
    ``CreateCopy``, and either handles the file's extension or (when it does
    not) declares a connection prefix that ``filename`` starts with.
    """
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        can_create = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                      or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not (can_create and drv.GetMetadataItem(gdal.DCAP_RASTER) is not None):
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            drv_list.append(drv.ShortName)
        else:
            prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
            if prefix is not None and filename.lower().startswith(prefix.lower()):
                drv_list.append(drv.ShortName)

    # GMT is registered before netCDF for opening reasons, but we want
    # netCDF to be used by default for output. Two entries must exist before
    # indexing [0] and [1].
    if (ext.lower() == 'nc' and len(drv_list) >= 2
            and drv_list[0].upper() == 'GMT' and drv_list[1].upper() == 'NETCDF'):
        drv_list = ['NETCDF', 'GMT']

    return drv_list
def GetOutputDriverFor(filename):
    """Guess the GDAL output driver short name for ``filename``.

    :param filename: output file name (or connection string)
    :return: the first matching driver's short name, or ``'GTiff'`` when no
        driver matches and the name has no extension
    :raises Exception: if no driver matches a name that has an extension
    """
    # Computed up front: it is needed both for the fallback and for the
    # "several drivers" message below.
    ext = GetExtension(filename)
    drv_list = GetOutputDriversFor(filename)
    if not drv_list:
        if not ext:
            return 'GTiff'
        raise Exception("Cannot guess driver for %s" % filename)
    if len(drv_list) > 1:
        print("Several drivers matching %s extension. Using %s" % (ext, drv_list[0]))
    return drv_list[0]