上传文件至 'x64/Release/models/mask_fill'

This commit is contained in:
keyeslll 2023-03-13 16:24:23 +08:00
parent b163a03dc9
commit 1ca146c0c5
2 changed files with 616 additions and 0 deletions

View File

@ -0,0 +1,548 @@
import os
import cv2
import numpy as np
import tifIO
import shutil
import sys
import argparse
import math
from math import floor
import json
import subprocess
import qgis
from qgis.core import * # attach main QGIS library
from qgis.utils import * # attach main python library
from qgis.analysis import *
from osgeo import gdal
#相对路径
qgis_path = sys.executable
process_path = qgis_path.split('bin')[0]+'apps\\qgis-ltr\\python\\plugins'
# process_path = 'D:/code/RsSurvey/RsSurvey_Build/x64/Release/QGIS_3.22.8/apps/qgis-ltr/python/plugins/'
if os.path.exists(process_path):
print("exists path:"+process_path)
sys.stdout.flush()
else:
print("fine no path:"+process_path)
sys.stdout.flush()
sys.path.append(process_path)
qgs = QgsApplication([], False)
qgs.initQgis()
import processing
from processing.core.Processing import Processing
Processing.initialize()
#添加参数
parser = argparse.ArgumentParser()
parser.add_argument('--in_holedem', type=str,
default='D:/A_testdata/dianli_test/00yanshou/data_train_dem/05veg_maskedTiff/',
help='输入数据经过裁剪的dem')
parser.add_argument('--work_space', type=str,
default='D:/A_testdata/dianli_test/00yanshou/data_train_dem/06veg_dataFill_workspace/',
help='用于填补时存放中间路径及结果文件')
parser.add_argument('--scale', type=float,
default=0.1,
help='外接矩形放大的比例')
parser.add_argument('--win_size', type=int,
default=3,
help='空洞外扩像素')
opt = parser.parse_args()
in_holedem = opt.in_holedem
work_space = opt.work_space
scale = opt.scale
win_size = opt.win_size
'''
txt文件名以in_holedem中dem数据文件名命名
内容以singlehole文件开始包含位置信息
subname, lt_x,lt_y,rb_x,rb_y
tktk130_mask_id_1_hw_83_106.tif,1364,1950,1470,2033
'''
out_txt_folder = 'txt_singlehole/'
singlehole_folder = 'dem_singlehole/'# 保存单个空洞
ras2pot_folder = 'raster2points/'
ras2potpro_folder = 'raster2points_pro/'
tinmesh_folder = 'TINMesh/'
potfrommesh_folder = 'pointfrommesh/'
potfrommeshfloat_folder = 'pointfrommeshfloat/'
filltif_folder = 'tif/'
filltifpro_folder = 'tif_pro/'
out_fulltif_folder = 'out_result/'# 填补并放回的整图结果
print("win_size:%d,scale:%.4f" % (win_size, scale))
sys.stdout.flush()
def findHoles(in_demfolder, out_singlehole, txt_folder):
count = 0
ev = floor(win_size / 2)
if os.path.exists(txt_folder):
try:
shutil.rmtree(txt_folder)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print("error {}".format(str(e)))
sys.stdout.flush()
print(-1.0)
sys.stdout.flush()
os.mkdir(txt_folder)
if os.path.exists(out_singlehole):
try:
shutil.rmtree(out_singlehole)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print("error {}".format(str(e)))
sys.stdout.flush()
print(-1)
sys.stdout.flush()
os.mkdir(out_singlehole)
dem_list = []
src_list = os.listdir(in_demfolder)
for f in src_list:
if os.path.splitext(f)[1] == '.tif' or os.path.splitext(f)[1] == '.TIF':
dem_list.append(f)
for imgfile in dem_list:
filename = os.path.splitext(imgfile)[0]
txtname = filename + '.txt'
txt_route = os.path.join(txt_folder, txtname)
file_handle = open(txt_route, mode='w')
# file_handle.writelines(['文件名,左上角列号,左上角行号,右下角列号,右下角行号\n'])
dem_route = os.path.join(in_demfolder, imgfile)
proj, geotrans, dem_img, width, height = tifIO.ReadTif(dem_route) # c,h,w
origin = cv2.imread(dem_route, cv2.IMREAD_UNCHANGED)
img = origin.copy()
# 二值化,转换类型
ret2, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV)
thresh_image = binary.astype(np.uint8)
# 连通域分析
contours, hierarchy = cv2.findContours(thresh_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) # 检索最外轮廓
# contours, hierarchy = cv2.findContours(thresh_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) # 检索所有轮廓
for i in range(0, len(contours)):
thisgeo = geotrans.copy()
lt_x, lt_y, w, h = cv2.boundingRect(contours[i]) # 最小包络矩形的四至
# 矩形外扩,最小外扩2像元,保证空洞再小也能外扩
w_ex = floor(w * scale / 2) if floor(w * scale / 2) > 2 else 2
h_ex = floor(h * scale / 2) if floor(h * scale / 2) > 2 else 2
if lt_x - w_ex >= 0:
new_lt_x = lt_x - w_ex
else:
new_lt_x = 0
if lt_y - h_ex >= 0:
new_lt_y = lt_y - h_ex
else:
new_lt_y = 0
if lt_x + w + w_ex <= width:
new_rb_x = lt_x + w + w_ex
else:
new_rb_x = width
if lt_y + h + h_ex <= height:
new_rb_y = lt_y + h + h_ex
else:
new_rb_y = height
this_hole = img[new_lt_y: new_rb_y, new_lt_x: new_rb_x]
flag = np.zeros(this_hole.shape, dtype=int)
flag[this_hole > -9999] = 1 # 如果低于-9999arcgis中nodata是极小的负数
# 洞不扩大
# this_hole[this_hole <=0] = np.nan
# 洞外扩
new_w = new_rb_x - new_lt_x
new_h = new_rb_y - new_lt_y
for c in range(0, new_w):
for r in range(0, new_h):
if flag[r, c] == 0:
for win_c in range(win_size):
for win_r in range(win_size):
tempr = r + win_r - ev
tempc = c + win_c - ev
if tempr >= 0 and tempr < new_h and tempc >= 0 and tempc < new_w:
this_hole[tempr, tempc] = np.nan
lon = thisgeo[0] + thisgeo[1] * new_lt_x
lat = thisgeo[3] + thisgeo[5] * new_lt_y
thisgeo[0] = lon
thisgeo[3] = lat
# fileroutesub =filename+ '_%d.tif' % count
fileroutesub = filename + '_id_%d_hw_%d_%d' % (i + 1, new_h, new_w) + '.tif'
fileroute = os.path.join(out_singlehole, fileroutesub)
this_hole[np.where(this_hole == 0.0)] = np.nan
tifIO.writeTif(fileroute, proj, thisgeo, this_hole) # 投影信息不对
# 坐标从【00】开始计左上坐标【x,y】右下坐标[x+w-1,y+h-1]
file_handle.writelines([fileroutesub, ',', str(new_lt_x), ',',
str(new_lt_y), ',', str(new_rb_x), ',', str(new_rb_y), '\n'])
count = count + 1
file_handle.close()
def confirmWorkspaceClear(rootPath):
out_path_raster2points = os.path.join(rootPath, ras2pot_folder)
if os.path.exists(out_path_raster2points):
try:
shutil.rmtree(out_path_raster2points)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_raster2points)
print(4)
sys.stdout.flush()
out_path_raster2points_pro = os.path.join(rootPath, ras2potpro_folder)
if os.path.exists(out_path_raster2points_pro):
try:
shutil.rmtree(out_path_raster2points_pro)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_raster2points_pro)
print(5)
sys.stdout.flush()
out_path_TINMesh = os.path.join(rootPath, tinmesh_folder)
if os.path.exists(out_path_TINMesh):
try:
shutil.rmtree(out_path_TINMesh)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_TINMesh)
print(6)
sys.stdout.flush()
out_path_pointfrommesh = os.path.join(rootPath, potfrommesh_folder)
if os.path.exists(out_path_pointfrommesh):
try:
shutil.rmtree(out_path_pointfrommesh)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_pointfrommesh)
print(7)
sys.stdout.flush()
out_path_pointfrommeshfloat = os.path.join(rootPath, potfrommeshfloat_folder)
if os.path.exists(out_path_pointfrommeshfloat):
try:
shutil.rmtree(out_path_pointfrommeshfloat)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_pointfrommeshfloat)
print(8)
sys.stdout.flush()
out_path_tif = os.path.join(rootPath, filltif_folder)
if os.path.exists(out_path_tif):
try:
shutil.rmtree(out_path_tif)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_tif)
print(9)
sys.stdout.flush()
out_path_tif_pro = os.path.join(rootPath, filltifpro_folder)
if os.path.exists(out_path_tif_pro):
try:
shutil.rmtree(out_path_tif_pro)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path_tif_pro)
print(10)
sys.stdout.flush()
def raster2points(fileName):
inFile=work_space+singlehole_folder+fileName
# 获取文件信息
rlayer = QgsRasterLayer(inFile, "SRTM layer name")
tif_crs =rlayer.crs().toWkt()
if not rlayer.isValid():
print("图层加载失败!")
dataset = gdal.Open(inFile)
gt = dataset.GetGeoTransform()
Xmin = gt[0]
Ymin = gt[3]
width = dataset.RasterXSize
height = dataset.RasterYSize
Xmax = gt[0] + width*gt[1] + height*gt[2]
Ymax = gt[3] + width*gt[4] + height*gt[5]
extent = [Xmin,Xmax,min(Ymin,Ymax),max(Ymin,Ymax)]
cellHeight = math.fabs(gt[1])
cellWidth = math.fabs(gt[5])
# #准备进行处理
# aid='saga:rastervaluestopoints'
out_path = os.path.join(work_space,ras2pot_folder)
outFile=out_path+fileName[0:fileName.index('.')]+'.geojson'
# p = {
# 'GRIDS': [inFile],
# 'NODATA': True,
# 'POLYGONS': None,
# 'SHAPES': outFile,
# 'TYPE': 0}
# processing.run(aid, p)
print(qgis_path)
exe_path= qgis_path.split('python.exe')[0]+"saga_cmd.exe"
print(exe_path)
cammand_line=exe_path+" shapes_grid 3 -GRIDS {} -POLYGONS None -NODATA 1 -SHAPES {} -TYPE 0".format(inFile,outFile)# --flags[qr]
process_status=subprocess.call(cammand_line)
if 0!=process_status:
print("error ")
pro_aid = 'native:assignprojection'
proFile = work_space+ras2potpro_folder+fileName[0:fileName.index('.')]+'.geojson'
pro_p = {
'CRS': QgsCoordinateReferenceSystem('EPSG:4546'),
'INPUT': outFile,
'OUTPUT': proFile,
}
processing.run(pro_aid, pro_p)
return fileName[0:fileName.index('.')]+'.geojson', cellHeight, cellWidth, tif_crs, extent
def points2mesh(fileName,tif_crs):
aid='native:tinmeshcreation'
inFile=work_space+ras2potpro_folder+fileName
optFile=work_space+tinmesh_folder+fileName[0:fileName.index('.')]+'.file'
p={ 'CRS_OUTPUT' : QgsCoordinateReferenceSystem('EPSG:4546'), 'MESH_FORMAT' : 2, 'OUTPUT_MESH' : optFile,
'SOURCE_DATA' : [{'source': inFile,'type': 0,'attributeIndex': 3}] }
processing.run(aid,p)
return fileName[0:fileName.index('.')]+'.file'
def getpointFromMesh(fileName, tif_crs,extent):
inFile = work_space + tinmesh_folder + fileName
if not os.path.exists(inFile):
inFile = work_space + tinmesh_folder + fileName + '.ply'
optFile = 'TEMPORARY_OUTPUT'
aid = 'native:exportmeshongrid'
# tif_crs
p = {'CRS_OUTPUT' : QgsCoordinateReferenceSystem('EPSG:4546'), 'DATASET_GROUPS': [0],
'DATASET_TIME': {'type': 'static'},
'EXTENT':None,
'GRID_SPACING': 0.05,
'INPUT' : inFile, 'OUTPUT': optFile, 'VECTOR_OPTION': 0}
# 'EXTENT': '%.8f,%.8f,%.8f,%.8f [EPSG:4546]' % (extent[0],extent[1],extent[2],extent[3]),
res = processing.run(aid, p)
# Save as a shapefile
Fl_ou = fileName[0:fileName.index('.')] + '.geojson'
Fl_ou = work_space + potfrommesh_folder + Fl_ou
options = QgsVectorFileWriter.SaveVectorOptions()
options.driverName = "geojson"
QgsVectorFileWriter.writeAsVectorFormatV2(res['OUTPUT'], Fl_ou, QgsCoordinateTransformContext(), options)
# TODO turn string to realNumber
return fileName[0:fileName.index('.')] + '.geojson'
def transString2float(fileName):
inFile = work_space + potfrommesh_folder + fileName
opt = work_space + potfrommeshfloat_folder + fileName
optFile = open(opt, 'w')
# minx=9999999
# maxx=0
# miny=99999999
# maxy=0
with open(inFile) as jsonFile:
points = json.load(jsonFile)
features = points['features']
for feature in features:
properties = feature['properties']
Bed_Elevation = float(properties['Bed Elevation'])
properties['Bed Elevation'] = Bed_Elevation
#
# geometry=feature['geometry']
'''
coordinates = geometry['coordinates']
x = coordinates[0]
if minx>x:
minx =x
if maxx<x:
maxx=x
y= coordinates[1]
if miny > y:
miny=y
if maxy <y:
maxy=y
'''
json.dump(points, optFile)
optFile.close()
return fileName
def point2tif(fileName,cellHeight,cellWidth,tif_crs,extent):
inFile=work_space+potfrommeshfloat_folder+fileName
# print(inFile)
optFile=work_space+filltif_folder+fileName[0:fileName.index('.')]+'.tif'
# print(extent)
aid='gdal:rasterize'
p = {'BURN':0, 'DATA_TYPE':5,
'EXTENT':None,
'EXTRA': '', 'FIELD' : 'Bed Elevation', 'HEIGHT' : cellHeight,'INIT' : None,
'INPUT': inFile, 'INVERT' : False, 'NODATA' : 0, 'OPTIONS' : '', 'OUTPUT' : optFile,
'UNITS': 1, 'USE_Z' : False, 'WIDTH' : cellWidth}
# p= { 'BURN':0, 'DATA_TYPE':5,
# 'EXTENT':'544158.5540865903, 544161.2540865903, 4450451.895210848, 4450455.295210849 [EPSG:4546]',
# 'EXTRA' : '', 'FIELD' : 'Bed Elevation', 'HEIGHT' : cellHeight,
# 'INIT' : None, 'INPUT' : inFile, 'INVERT' : False, 'NODATA' : 0,
# 'OPTIONS' : '',
# 'OUTPUT' : optFile, 'UNITS' : 1, 'USE_Z' : False, 'WIDTH' : cellWidth}
processing.run(aid,p)
dataset=gdal.Open(optFile)
width = dataset.RasterXSize
height = dataset.RasterYSize
# print('out wid,hei:',width,height)
optFile2=work_space+'tif2/'+fileName[0:fileName.index('.')]+'.tif'
# print(optFile2)
pro_aid='gdal:warpreproject'
# proFile=optFile
# proFile=rootPath+'raster2points_pro/'+fileName[0:fileName.index('.')]+'.shp'
# pro_p={ 'CRS' : QgsCoordinateReferenceSystem('EPSG:4546'), 'INPUT' : outFile, 'OUTPUT' :proFile}
pro_p={'SOURCE_CRS' : QgsCoordinateReferenceSystem('EPSG:4546'),
'TARGET_CRS' : QgsCoordinateReferenceSystem(tif_crs),
'EXTENT': '%.8f,%.8f,%.8f,%.8f [EPSG:4546]' % (extent[0], extent[1], extent[2], extent[3]),
'RESAMPLING':2,'INPUT' : optFile, 'OUTPUT' :optFile2}
processing.run(pro_aid,pro_p)
return fileName
def processOneTif(fileName):
# 转矢量点,赋投影
processedFile, cellHeight, cellWidth, tif_crs, extent = raster2points(fileName)
meshFile = points2mesh(processedFile, tif_crs)
pointsFile = getpointFromMesh(meshFile, tif_crs, extent)
pointsFilefloat = transString2float(pointsFile)
point2tif(pointsFilefloat, cellHeight, cellWidth, tif_crs, extent)
def demFill():
# print("demFill")
# sys.stdout.flush()
single_files_list=os.listdir(work_space+singlehole_folder)
temp_file_lists = []
for file in single_files_list:
if os.path.splitext(file)[1] == '.tif' or os.path.splitext(file)[1] == '.TIF':
temp_file_lists.append(file)
total = len(temp_file_lists)
for index, file in enumerate(temp_file_lists):
sys.stdout.flush()
processOneTif(file)
print(10+float((index+1) / total * 79))
sys.stdout.flush()
def back2Whole():
txtList=[]
txt_dir=work_space+out_txt_folder
sub_list = os.listdir(txt_dir)
#先找txt,保存文件名到list中
for f in sub_list:
if os.path.splitext(f)[1] == '.txt' or os.path.splitext(f)[1] == '.TXT':
txtList.append(f)
txtNum=len(txtList)
if txtNum<=0:
# print("There is no holes to join")
print(-1.0)
sys.stdout.flush()
currentIndicator=0
for txtfile in txtList:
#根据文件名找原来的图像文件,拷贝
imgname=os.path.splitext(txtfile)[0]
demfile=imgname+'.tif'# txt文件名与dem文件名对应
txt_route= os.path.join(txt_dir, txtfile)# txt文件
dem_route = os.path.join(in_holedem, demfile)# 空缺的整图
proj, geotrans, dem_img, width, height = tifIO.ReadTif(dem_route) # c,h,w
img = dem_img.copy()
#读txt,找到对应的小矩形文件
file_handle=open(txt_route,mode="r")
lines=file_handle.readlines()
for i in lines:
s=i.split(',')
#['文件名,左上角列号,左上角行号,右下角列号,右下角行号\n']
subname=s[0]
lt_x=int(s[1])
lt_y=int(s[2])
rb_x=int(s[3])
rb_y=int(s[4])
# print(subname+" img's lt_x:%d,lt_y:%d,rb_x:%d,rb_y:%d" % (lt_x,lt_y,rb_x,rb_y))
# 读相应的文件,根据行列号替换
sub_route = os.path.join(work_space+filltif_folder, subname)
proj_sub, geotrans_sub, sub_img, w_sub, h_sub = tifIO.ReadTif(sub_route)
sp = sub_img.shape
img_w = sp[1]
img_h = sp[0]
# print(subname + " img's w:%d,h:%d" % (img_w,img_h))
# img[lt_y:rb_y,lt_x:rb_x]=sub_img
img[lt_y:lt_y+img_h,lt_x:lt_x+img_w]=sub_img
file_handle.close()
out_path = os.path.join(work_space,out_fulltif_folder)
if os.path.exists(out_path):
try:
shutil.rmtree(out_path)
time.sleep(2) # 防止删除操作未结束就运行mkdir()
except Exception as e:
print(-1.0)
sys.stdout.flush()
os.mkdir(out_path)
out_wholefile=imgname+'_whole.tif'
tifIO.writeTif(out_path+out_wholefile, proj, geotrans, img)
currentIndicator=currentIndicator+1
print(90.0+float(currentIndicator/txtNum)*10)
sys.stdout.flush()
if __name__ == '__main__':
print(0)
sys.stdout.flush()
in_demfolder = in_holedem
out_singlehole = os.path.join(work_space, singlehole_folder)
txt_folder = os.path.join(work_space, out_txt_folder)
# 执行寻找并保存单个空洞
findHoles(in_demfolder,out_singlehole,txt_folder)
print(2)
sys.stdout.flush()
confirmWorkspaceClear(work_space)
demFill()
back2Whole()

View File

@ -0,0 +1,68 @@
from osgeo import gdal
# from tqdm import tqdm
def ReadTif(tif_path):
dataset = gdal.Open(tif_path)
width = dataset.RasterXSize
height = dataset.RasterYSize
geotrans = list(dataset.GetGeoTransform()) # 仿射矩阵
proj = dataset.GetProjection() # 地图投影信息
data = dataset.ReadAsArray(0, 0, width, height) # 将数据写成数组,对应栅格矩阵
del dataset # 关闭对象文件dataset
return proj, geotrans, data, width, height
def writeTif(fileroute, im_proj, im_geotrans, im_data):
# 判断栅格数据的数据类型
"""
GDAL中的GDALDataType是一个枚举型其中的值为
GDT_Unknown : 未知数据类型
GDT_Byte : 8bit正整型 (C++中对应unsigned char)
GDT_UInt16 : 16bit正整型 (C++中对应 unsigned short)
GDT_Int16 : 16bit整型 (C++中对应 short short int)
GDT_UInt32 : 32bit 正整型 (C++中对应unsigned long)
GDT_Int32 : 32bit整型 (C++中对应int long long int)
GDT_Float32 : 32bit 浮点型 (C++中对应float)
GDT_Float64 : 64bit 浮点型 (C++中对应double)
GDT_CInt16 : 16bit复整型 (?)
GDT_CInt32 : 32bit复整型 (?)
GDT_CFloat32 : 32bit复浮点型 (?)
GDT_CFloat64 : 64bit复浮点型 (?)
"""
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
im_data = im_data.reshape(im_bands,im_height, im_width)
# 创建文件
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(fileroute, im_width,im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
if im_bands == 1:
dataset.GetRasterBand(1).WriteArray(im_data[0]) # 写入数组数据
else:
#for i in range(im_bands):
for i in tqdm(range(im_bands)):
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
def np2gdal(hwc): # hwc2chw
chw = hwc.swapaxes(2, 0).swapaxes(1, 2) # h,w,c to c,h,w
return chw
def gdal2np(chw): #chw2hwc
hwc = chw.swapaxes(1, 0).swapaxes(1, 2) # h,w,c
return hwc