import math
import os

import cv2
import numba
import numpy as np
from bidict import bidict
from osgeo import gdal, gdalconst
from osgeo import osr
# Two-way lookup between GDAL raster data-type codes and NumPy scalar types.
# Forward:  dataTypeMap[gdal.GDT_Byte]        -> np.uint8
# Reverse:  dataTypeMap.inverse[np.float32]   -> gdal.GDT_Float32
# GDT_Unknown is paired with the generic `object` type as a placeholder.
_GDAL_NUMPY_PAIRS = [
    (gdal.GDT_Unknown, object),
    (gdal.GDT_Byte, np.uint8),
    (gdal.GDT_UInt16, np.uint16),
    (gdal.GDT_Int16, np.int16),
    (gdal.GDT_UInt32, np.uint32),
    (gdal.GDT_Int32, np.int32),
    (gdal.GDT_Float32, np.float32),
    (gdal.GDT_Float64, np.float64),
]
dataTypeMap = bidict(_GDAL_NUMPY_PAIRS)
def Entropy(src):
    """Calculate the information entropy of source image.

    The Shannon entropy (in bits) of the grey-level histogram::

        H = -sum_k p_k * log2(p_k)      over grey levels k with p_k > 0

    where ``p_k`` is the fraction of pixels whose value is ``k``.

    :param src: source image, it should be read by `cv2.imread()` and grayscale
        image. Any array of non-negative integers is accepted (e.g. uint8 or
        uint16); all elements are pooled into a single histogram.
    :return: information entropy
    :rtype: float
    :raises ValueError: if ``src`` is empty or contains negative values.
    """
    values = np.asarray(src).ravel()
    if values.size == 0:
        raise ValueError("cannot compute the entropy of an empty image")
    # bincount sizes the histogram to the largest value present, so images
    # deeper than 8 bits work; it also rejects negative values instead of
    # letting them wrap around into the top bins.
    counts = np.bincount(values)
    probs = counts[counts > 0] / values.size
    return float(-np.sum(probs * np.log2(probs)))
def avg_grad(src):
    """Calculate the average gradient of source image.

    For every pixel that has both a right and a lower neighbour, the forward
    differences ``dx = F[i, j+1] - F[i, j]`` and ``dy = F[i+1, j] - F[i, j]``
    are combined as ``sqrt((dx**2 + dy**2) / 2)``. The result is the mean of
    these ``(rows - 1) * (cols - 1)`` values.

    NOTE: the mean is taken over the (rows-1)*(cols-1) samples that are
    actually summed. A previous implementation divided the same sum by
    rows*cols, which biased the value low (noticeably so for small images).

    :param src: source image, it should be read by `cv2.imread()` and grayscale image
    :return: average gradient (0.0 for images with fewer than 2 rows or columns)
    :rtype: float
    """
    # Work in float64 so differences of unsigned pixels cannot wrap around.
    img = np.asarray(src, dtype=np.float64)
    if img.shape[0] < 2 or img.shape[1] < 2:
        # No pixel has both neighbours, so there are no gradient samples.
        return 0.0
    origin = img[:-1, :-1]
    dx = img[:-1, 1:] - origin
    dy = img[1:, :-1] - origin
    return float(np.mean(np.sqrt((dx * dx + dy * dy) / 2)))
def RMSE(src, pan):
    """Calculate the root-mean-square error between two images.

    ``sqrt(mean((src - pan) ** 2))`` over all pixels.

    :param src: source image, it should be read by `cv2.imread()` and grayscale image
    :param pan: reference image with the same shape as ``src``
    :return: RMSE
    :rtype: float
    :raises ValueError: if the shapes differ or the images are empty.
    """
    # float64 so that differences of unsigned pixels cannot wrap around.
    a = np.asarray(src, dtype=np.float64)
    b = np.asarray(pan, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("shape mismatch: {} vs {}".format(a.shape, b.shape))
    if a.size == 0:
        raise ValueError("cannot compute the RMSE of empty images")
    diff = a - b
    return math.sqrt(np.mean(diff * diff))
def R_value(src, pan):
    """Calculate the correlation coefficient R between two images.

    The value returned is Pearson's correlation coefficient r (not r squared)::

        r = sum((s - mean_s) * (p - mean_p))
            / sqrt(sum((s - mean_s)**2) * sum((p - mean_p)**2))

    Means and sums are taken over all elements of each array.

    :param src: source image, it should be read by `cv2.imread()` and grayscale image
    :param pan: reference image with the same shape as ``src``
    :return: R value, in [-1, 1]
    :rtype: float
    :raises ValueError: if the shapes differ, the images are empty, or either
        image is constant (r is undefined).
    """
    a = np.asarray(src, dtype=np.float64)
    b = np.asarray(pan, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("shape mismatch: {} vs {}".format(a.shape, b.shape))
    if a.size == 0:
        raise ValueError("cannot correlate empty images")
    da = a - a.mean()
    db = b - b.mean()
    denom = math.sqrt(np.sum(da * da) * np.sum(db * db))
    if denom == 0:
        raise ValueError("correlation is undefined for a constant image")
    return float(np.sum(da * db) / denom)
def _Rsqure(src, pan, src_mean, pan_mean):
    """Return Pearson's correlation coefficient r given precomputed means.

    :param src: first image (converted to float64 internally)
    :param pan: second image, same shape as ``src``
    :param src_mean: mean of ``src``
    :param pan_mean: mean of ``pan``
    :return: ``sum(ds * dp) / sqrt(sum(ds**2) * sum(dp**2))`` where
        ``ds = src - src_mean`` and ``dp = pan - pan_mean``
    :rtype: float
    :raises ValueError: if the shapes differ or either image has zero
        variance about its mean (r is undefined).
    """
    ds = np.asarray(src, dtype=np.float64) - src_mean
    dp = np.asarray(pan, dtype=np.float64) - pan_mean
    if ds.shape != dp.shape:
        raise ValueError("shape mismatch: {} vs {}".format(ds.shape, dp.shape))
    denom = math.sqrt(np.sum(ds * ds) * np.sum(dp * dp))
    if denom == 0:
        raise ValueError("correlation is undefined for a constant image")
    return float(np.sum(ds * dp) / denom)
def piancha_relativepiancha(src, pan):
    """Calculate the mean absolute deviation and mean relative deviation.

    With ``d = |src - pan|`` and ``N`` the total number of pixels:

    * first value:  ``sum(d) / N``
    * second value: ``sum(d / pan over pixels where pan != 0) / N``

    Pixels where ``pan == 0`` contribute nothing to the second sum but are
    still counted in ``N``.

    :param src: source image, it should be read by `cv2.imread()` and grayscale image
    :param pan: reference image with the same shape as ``src``
    :return: ``(mean_deviation, mean_relative_deviation)``
    :rtype: tuple[float, float]
    :raises ValueError: if the shapes differ or the images are empty.
    """
    # float64 so that differences of unsigned pixels cannot wrap around.
    a = np.asarray(src, dtype=np.float64)
    b = np.asarray(pan, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("shape mismatch: {} vs {}".format(a.shape, b.shape))
    count = a.size
    if count == 0:
        raise ValueError("cannot compare empty images")
    deviation = np.abs(a - b)
    nonzero = b != 0  # skip division by zero
    relative = np.sum(deviation[nonzero] / b[nonzero])
    return float(np.sum(deviation) / count), float(relative / count)
def scala_image(img, rg):
    """Linearly stretch every channel of an image into the range ``rg``.

    Each channel is rescaled independently so that its minimum maps to
    ``rg[0]`` and its maximum to ``rg[1]``. The arithmetic is done in float32
    and the result is truncated to uint8. A channel whose values are all equal
    is mapped to ``rg[0]``.

    :param img: image of shape (H, W, C), e.g. from `cv2.imread()`, or a
        single-channel (H, W) image
    :param rg: ``(low, high)`` target range; should lie within [0, 255]
        because the result is stored as uint8
    :return: rescaled image with the same shape as ``img``
    :rtype: numpy.ndarray of uint8
    :raises ValueError: if ``rg[0] > rg[1]``.
    """
    if rg[0] > rg[1]:
        raise ValueError("range error")
    data = np.asarray(img).astype(np.float32)
    # Per-channel extrema over the two spatial axes; keepdims lets them
    # broadcast back against ``data`` for any channel count.
    lo = data.min(axis=(0, 1), keepdims=True)
    hi = data.max(axis=(0, 1), keepdims=True)
    span = hi - lo
    # Avoid 0/0 (NaN) for constant channels; (data - lo) is 0 there anyway.
    safe_span = np.where(span == 0, np.float32(1), span)
    scaled = ((data - lo) / safe_span) * (rg[1] - rg[0]) + rg[0]
    return scaled.astype(np.uint8)
class Image:
    """Wrapper around an `osgeo.gdal.Dataset` that adds slice-based pixel access.

    Pixels are addressed as ``img[row, col]``. Each of ``row`` and ``col`` is
    either an integer or a slice whose step is ``None`` or ``1``. A slice with
    a missing start/stop extends to the corresponding raster edge.

    * Reading returns ``Dataset.ReadAsArray(xoff, yoff, xsize, ysize,
      interleave="pixel")`` for the selected window.
    * Writing expects a NumPy array of shape ``(rows, cols, band_count)`` whose
      scalar type matches the dataset's band type according to
      ``dataTypeMap``. All bands are written at once.

    The wrapped dataset is exposed through the ``image`` attribute.
    """

    def __init__(self, image=None):
        self._image = image  # wrapped GDAL Dataset, or None

    @property
    def image(self):
        """The wrapped `osgeo.gdal.Dataset` (or None)."""
        return self._image

    @image.deleter
    def image(self):
        self._image = None

    @staticmethod
    def _is_index(obj):
        """Return True if ``obj`` is a scalar integer index (Python or NumPy)."""
        return isinstance(obj, (int, np.integer))

    @staticmethod
    def _slice_span(s, full):
        """Return ``(offset, size)`` selected by slice ``s`` on an axis of length ``full``."""
        offset = s.start if s.start is not None else 0
        size = (s.stop - offset) if s.stop is not None else (full - offset)
        return offset, size

    def _window(self, key):
        """Translate a ``(row, col)`` key into a GDAL window ``(xoff, yoff, xsize, ysize)``."""
        if not isinstance(key, tuple):
            raise TypeError("It must be a tuple")
        if len(key) != 2:
            raise NotImplementedError("Not implement!")
        if not all(self._is_index(k) or isinstance(k, slice) for k in key):
            raise ValueError("Ileagal slice, {}".format(key))
        steps = [k.step for k in key if isinstance(k, slice)]
        # Reject the key if ANY slice has a step other than 1.
        if any(step not in (None, 1) for step in steps):
            if len(steps) == 2:
                raise ValueError("The slice step can not bigger than 1, we get {} and {}".format(*steps))
            raise ValueError("The slice step can not bigger than 1, we get {}".format(steps[0]))
        row, col = key
        if isinstance(col, slice):
            xoff, xsize = self._slice_span(col, self._image.RasterXSize)
        else:
            xoff, xsize = int(col), 1
        if isinstance(row, slice):
            yoff, ysize = self._slice_span(row, self._image.RasterYSize)
        else:
            yoff, ysize = int(row), 1
        return xoff, yoff, xsize, ysize

    def __getitem__(self, item):
        """Read the window selected by ``item`` with pixel interleaving."""
        if self._image is None:
            raise ValueError("Image is None")
        xoff, yoff, xsize, ysize = self._window(item)
        return self._image.ReadAsArray(xoff, yoff, xsize, ysize, interleave="pixel")

    def __setitem__(self, key, value: np.ndarray):
        """Write ``value`` (shape ``(rows, cols, band_count)``) into the window ``key``.

        :raises ValueError: if no dataset is wrapped, the window falls outside
            the raster, or ``value`` has the wrong shape.
        :raises TypeError: if ``key`` is not a tuple or the dtype does not
            match the dataset's band type.
        :raises Exception: with GDAL's last error message if the write fails.
        """
        if self._image is None:
            raise ValueError("Image is None")
        data_type = self._image.GetRasterBand(1).DataType
        # .get: an unmapped GDAL type is reported as a type mismatch, not a KeyError.
        if dataTypeMap.get(data_type) != value.dtype.type:
            raise TypeError("type not match:", data_type, value.dtype.type)
        xoff, yoff, xsize, ysize = self._window(key)
        XSize = self._image.RasterXSize
        YSize = self._image.RasterYSize
        if min(xoff, yoff, xsize, ysize) < 0 or xoff + xsize > XSize or yoff + ysize > YSize:
            raise ValueError("Shape error:", xoff, yoff, xoff + xsize, yoff + ysize)
        bn = self._image.RasterCount
        if value.shape != (ysize, xsize, bn):
            raise ValueError("wrong shape:", value.shape)
        # With default spacing, Dataset.WriteRaster reads the buffer as
        # band-sequential (all of band 1, then band 2, ...), so move the band
        # axis to the front. This applies to single rows/columns as well.
        band_sequential = np.moveaxis(value, 2, 0).tobytes()
        if self._image.WriteRaster(xoff, yoff, xsize, ysize, band_sequential,
                                   xsize, ysize, band_list=list(range(1, bn + 1))):
            raise Exception(gdal.GetLastErrorMsg())
def ReadImageFile(filepath: str):
    """Read image

    Opens the file with `gdal.OpenShared` in read-only mode
    (`gdal.GA_ReadOnly`) and wraps the dataset in an `Image`.

    :param filepath: The file path
    :return: Image
    :rtype: `pyfusion.utils.Image`
    :raises OSError: if GDAL cannot open the file.
    """
    dataset = gdal.OpenShared(filepath, gdal.GA_ReadOnly)
    if dataset is None:
        # GDAL reports failure by returning None (when exceptions are not
        # enabled). Fail here instead of returning an Image that errors later.
        raise OSError("cannot open {}: {}".format(filepath, gdal.GetLastErrorMsg()))
    return Image(image=dataset)
def DoesDriverHandleExtension(drv, ext):
    """Return True if GDAL driver ``drv`` lists ``ext`` among its file extensions.

    ``DMD_EXTENSIONS`` is a space-separated list (e.g. ``"tif tiff"``). The
    comparison is a case-insensitive match on whole entries, so a partial
    string such as ``"t"`` does not match ``"tif"``.

    :param drv: a `gdal.Driver`
    :param ext: extension without the leading dot
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if exts is None:
        return False
    return ext.lower() in exts.lower().split()
def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    Uses `os.path.splitext`, so only the last suffix is returned
    (``"a.tar.gz"`` -> ``"gz"``). Returns ``""`` when there is no extension,
    including for dot-files such as ``".bashrc"``.
    """
    suffix = os.path.splitext(filename)[1]
    return suffix[1:] if suffix.startswith('.') else suffix
def GetOutputDriversFor(filename):
    """Return short names of GDAL raster drivers able to write ``filename``.

    A driver qualifies if it advertises raster support and either
    ``DCAP_CREATE`` or ``DCAP_CREATECOPY``. It then matches if its extension
    list contains the file's extension, or otherwise if ``filename`` starts with
    the driver's ``DMD_CONNECTION_PREFIX``. Drivers are listed in GDAL
    registration order.
    """
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        can_write = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or drv.GetMetadataItem(gdal.DCAP_RASTER) is None:
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            drv_list.append(drv.ShortName)
        else:
            prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
            if prefix is not None and filename.lower().startswith(prefix.lower()):
                drv_list.append(drv.ShortName)
    # GMT is registered before netCDF for opening reasons, but we want
    # netCDF to be used by default for output. Both entries must be present
    # before they can be inspected.
    if (ext.lower() == 'nc' and len(drv_list) >= 2
            and drv_list[0].upper() == 'GMT' and drv_list[1].upper() == 'NETCDF'):
        drv_list = ['NETCDF', 'GMT']
    return drv_list
def GetOutputDriverFor(filename):
    """Pick the GDAL driver short name to use when writing ``filename``.

    Returns ``'GTiff'`` when the filename has no extension and no driver
    matched. If several drivers match, prints a notice and uses the first.

    :raises Exception: if the filename has an extension no driver handles.
    """
    drv_list = GetOutputDriversFor(filename)
    # Needed both for the no-match decision and for the multi-match notice.
    ext = GetExtension(filename)
    if not drv_list:
        if not ext:
            return 'GTiff'
        raise Exception("Cannot guess driver for %s" % filename)
    if len(drv_list) > 1:
        print("Several drivers matching %s extension. Using %s" % (ext, drv_list[0]))
    return drv_list[0]