import os
import shutil
import struct
import tempfile
from math import floor
from pathlib import Path

from osgeo import gdal, osr, ogr
from PIL import Image
import rasterio as rio
from rasterio.plot import show

from Priority_Queue import PQueue
# from Output import tnm_api_download_raster


# Pixel offsets (dx, dy) of the 8 cells around a cell, in the order they are
# pushed onto the neighbour queue: top row left-to-right, middle-left,
# middle-right, then the bottom row left-to-right. adj_values scales each
# offset by the current buffer step.
_NEIGHBOR_OFFSETS = (
    (-1, -1), (0, -1), (1, -1),
    (-1, 0), (1, 0),
    (-1, 1), (0, 1), (1, 1),
)


def _in_bounds(px, py, x_count, y_count):
    """Return True when pixel (px, py) lies inside an x_count by y_count raster."""
    return 0 <= px < x_count and 0 <= py < y_count


def cell_value(px, py, raster_band):
    """
    Return the value of a single raster cell.

    :param px: pixel column
    :param py: pixel row
    :param raster_band: GDAL band to read from
    :return: the cell value, read as a 32-bit float
    :raises RuntimeError: if GDAL could not read the cell (e.g. it lies outside the band)
    """
    struct_val = raster_band.ReadRaster(px, py, 1, 1, buf_type=gdal.GDT_Float32)
    if struct_val is None:
        raise RuntimeError("Could not read pixel ({0}, {1}) from raster band".format(px, py))
    # ReadRaster hands back a raw buffer in native byte order, hence native 'f'.
    return struct.unpack('f', struct_val)[0]


def adj_values(px, py, raster_band, neighbors, x_count, y_count, buffer=1):
    """
    Push the cells surrounding (px, py) onto the ``neighbors`` queue.

    For every step ``x`` and ``y`` in ``1..buffer`` the cells at offsets
    (-x, -y), (0, -y), (+x, -y), (-x, 0), (+x, 0), (-x, +y), (0, +y), (+x, +y)
    are read and pushed as ``neighbors.push((col, row), value)``. With
    ``buffer > 1`` the cells straight above/below/left/right of (px, py) are
    pushed more than once.

    The walk stops at the first cell that falls outside the raster. Cells
    visited before it have already been pushed.

    :param px: pixel column of the centre cell
    :param py: pixel row of the centre cell
    :param raster_band: GDAL band to read values from
    :param neighbors: priority queue receiving ``((col, row), value)`` entries
    :param x_count: raster width in pixels
    :param y_count: raster height in pixels
    :param buffer: how many cells out from the centre to sample
    :return: ``(None, None)`` if every cell was inside the raster, otherwise the
        ``(col, row)`` of the first out-of-bounds cell (may be negative)
    """
    # TODO: trigger pulling next tile (and join it) instead of bailing out.
    for x in range(1, buffer + 1):
        for y in range(1, buffer + 1):
            for dx, dy in _NEIGHBOR_OFFSETS:
                nx, ny = px + dx * x, py + dy * y
                if not _in_bounds(nx, ny, x_count, y_count):
                    return nx, ny
                neighbors.push((nx, ny), cell_value(nx, ny, raster_band))
    return None, None


def check_neighbors(test_value, point, neighbors, rb, input_point, summit_found, buffer, x_count, y_count):
    """
    Hill-climb from ``point`` through the queued cells until a summit is confirmed.

    Each step pops a cell from ``neighbors``. The queue is presumably a max-queue
    (highest value first); confirm against Priority_Queue.PQueue.
    - If the popped cell is higher than the current value, it becomes the new
      current cell and its 8 neighbours are queued.
    - Otherwise the cells within ``buffer`` of the current cell are checked. If
      none of them is higher, the current cell is the summit, and
      ``input_point.corrected_summit_x``, ``corrected_summit_y`` and
      ``corrected_value`` are set. If one is higher, the climb moves there.

    The climb is iterative rather than recursive, so long slopes cannot hit
    Python's recursion limit.

    :param test_value: value of the current cell
    :param point: ``(col, row)`` of the current cell
    :param neighbors: priority queue shared with adj_values
    :param rb: GDAL band being searched
    :param input_point: object on which the corrected summit is recorded
    :param summit_found: pass False; when True nothing is searched
    :param buffer: radius, in cells, used to confirm a summit
    :param x_count: raster width in pixels
    :param y_count: raster height in pixels
    :return: ``(None, None)`` when the climb ended (summit recorded or queue
        exhausted), otherwise the ``(col, row)`` of an off-raster cell whose tile
        is needed to continue
    """
    while neighbors.length() != 0 and not summit_found:
        popped_x, popped_y, popped_value = neighbors.pop_point()
        if not _in_bounds(popped_x, popped_y, x_count, y_count):
            print("outside of file")

        if popped_value > test_value:
            # Move the point to the better summit candidate and queue its neighbours.
            new_px, new_py = adj_values(popped_x, popped_y, rb, neighbors, x_count, y_count)
            if new_px is not None:
                return new_px, new_py
            test_value, point = popped_value, (popped_x, popped_y)
        elif popped_value <= test_value:
            summit_found, cand_x, cand_y, cand_value, new_px, new_py = check_neighbors_within_buffer(
                point[0], point[1], buffer, rb, neighbors, test_value, x_count, y_count)
            if new_px is not None:
                return new_px, new_py
            if summit_found:
                input_point.corrected_summit_x = cand_x
                input_point.corrected_summit_y = cand_y
                input_point.corrected_value = cand_value
            else:
                # A higher cell exists within the buffer: continue climbing from it.
                test_value, point = cand_value, (cand_x, cand_y)
    return None, None


def check_neighbors_within_buffer(px, py, buffer, rb, neighbors, test_value, x_count, y_count):
    """
    Check whether any cell within ``buffer`` of (px, py) is higher than ``test_value``.

    :return: a 6-tuple ``(summit_found, x, y, value, new_px, new_py)``:
        - ``(True, px, py, test_value, None, None)`` if the popped cell is not
          higher, i.e. (px, py) is the summit. A NaN popped value also lands here.
        - ``(False, x, y, value, None, None)`` for a higher cell at (x, y).
        - ``(None, None, None, None, new_px, new_py)`` if the buffer reaches off
          the raster at (new_px, new_py).
    """
    new_px, new_py = adj_values(px, py, rb, neighbors, x_count, y_count, buffer)
    if new_px is not None:
        return None, None, None, None, new_px, new_py
    popped_x, popped_y, popped_value = neighbors.pop_point()
    if popped_value > test_value:
        return False, popped_x, popped_y, popped_value, None, None
    return True, px, py, test_value, None, None


def get_raster_info(raster):
    """
    Collect the band, transforms, extent origin, cell size and SRS of a GDAL dataset.

    :param raster: opened GDAL dataset
    :return: ``(rb, gt_reverse, gt_forward, xmin, ymax, cell_size_x, cell_size_y,
        raster_srs, x_count, y_count)``, where ``rb`` is band 1,
        ``gt_forward`` maps pixel to map coordinates, ``gt_reverse`` is its
        inverse, and ``x_count`` / ``y_count`` are the band width / height in pixels
    """
    rb = raster.GetRasterBand(1)
    # Band dimensions, without reading every pixel into memory.
    x_count, y_count = rb.XSize, rb.YSize

    raster_srs = osr.SpatialReference(wkt=raster.GetProjection())

    # Coefficients for the pixel/line -> projected coordinate transform.
    gt_forward = raster.GetGeoTransform()
    xmin, ymax = gt_forward[0], gt_forward[3]
    cell_size_x = gt_forward[1]
    # For north-up rasters GDAL stores a negative pixel height (rows increase
    # southward). Negating it yields a positive cell size. Only cell_size_y is
    # affected; map coordinates are unchanged.
    cell_size_y = -gt_forward[5]

    # Inverse transform, used to convert map points to pixels.
    gt_reverse = gdal.InvGeoTransform(gt_forward)

    return rb, gt_reverse, gt_forward, xmin, ymax, cell_size_x, cell_size_y, raster_srs, x_count, y_count


def build_vrt_from_dir(src_dir, dst_vrt, args=None):
    """
    Build a VRT mosaic from a list of rasters or from every raster in a directory.

    :param src_dir: list/tuple of raster paths, or a directory searched
        (non-recursively) for ``.tif`` and ``.vrt`` files
    :param dst_vrt: output VRT path; must end in ``.vrt``. An existing file is
        replaced and missing parent directories are created.
    :param args: options forwarded to ``gdal.BuildVRT`` (e.g. a gdalbuildvrt
        command-line string)
    :raises Exception: if ``src_dir`` does not exist or ``dst_vrt`` is not a VRT path
    :return: None
    """
    # gdal cache to 2 GiB
    gdal.SetCacheMax(2 ** 31)

    if isinstance(src_dir, (list, tuple)):
        files = list(src_dir)
    elif not os.path.exists(src_dir):
        raise Exception('Source directory does not exist!')
    elif os.path.isdir(src_dir):
        # Sorted so the mosaic is deterministic: on overlap, gdalbuildvrt takes
        # pixels from the sources listed last.
        files = [os.path.join(src_dir, filename)
                 for filename in sorted(os.listdir(src_dir))
                 if filename.endswith(('.tif', '.vrt'))]
    else:
        files = []

    if not dst_vrt.endswith('.vrt'):
        raise Exception('Invalid destination dataset: must be VRT format')
    if os.path.isfile(dst_vrt):
        os.remove(dst_vrt)

    # A bare file name has an empty dirname (current directory): nothing to create.
    dst_dir = os.path.dirname(dst_vrt)
    if dst_dir and not os.path.exists(dst_dir):
        os.makedirs(dst_dir)

    if args:
        gdal.BuildVRT(dst_vrt, files, options=args)
    else:
        gdal.BuildVRT(dst_vrt, files)
    return None


def _resave_with_pil(data_file):
    """Round-trip ``data_file`` through a sibling ``.jpg`` copy with Pillow and write it back in place."""
    # NOTE(review): JPEG is lossy and 8-bit, and Pillow cannot write 32-bit
    # float ('F' mode) images as JPEG. Confirm this repair is safe for
    # elevation data.
    jpg_file = str(Path(data_file).with_suffix('.jpg'))
    with Image.open(data_file) as im:
        im.save(jpg_file)
    with Image.open(jpg_file) as im:
        im.save(data_file)


def _probe_tiles(pt, data, work_dir):
    """
    Check that GDAL can warp every file in ``data.location``, repairing and flagging failures.

    Returns the list of files to mosaic. Every file is included, repaired or not.
    """
    tif_list = []
    for data_file in data.location:
        # TODO: check if warping to same SR is ok
        # The warped output is only a readability probe. It goes to the
        # per-call scratch directory and is discarded.
        probe_vrt = os.path.join(work_dir, 'probe_' + Path(data_file).stem + '.vrt')
        try:
            gdal.Warp(probe_vrt, data_file, dstSRS='EPSG:4269')  # in_srs.GetAuthorityCode(None))
        except Exception:
            # NOTE(review): only reached when GDAL exceptions are enabled;
            # otherwise a failed Warp returns None silently.
            _resave_with_pil(data_file)
            if data.type == "National Elevation Dataset (NED) 1/3 arc-second":
                pt.read_error_NED = True
            elif data.type == "Alaska IFSAR 5 meter DEM":
                pt.read_error_5M = True
        tif_list.append(data_file)
    return tif_list


def _search_mosaic(pt, data, vrt_dataset, tif_list, buffer, in_srs, lock, added):
    """Locate ``pt`` on the opened mosaic and hill-climb to its summit. Same return contract as find_raster_summit."""
    # Priority queue for storing neighboring cells
    neighbors = PQueue()

    ndv = vrt_dataset.GetRasterBand(1).GetNoDataValue()
    rb, gt_reverse, gt_forward, xmin, ymax, cell_size_x, cell_size_y, raster_srs, x_count, y_count = \
        get_raster_info(vrt_dataset)

    # set cell size of dataset
    data.cell_size_x = cell_size_x
    data.cell_size_y = cell_size_y

    # (x, y) = (easting/longitude, northing/latitude) order, matching the
    # geotransform. This mutates raster_srs, which is also returned to the caller.
    raster_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    trans_point = ogr.Geometry(ogr.wkbPoint)
    if pt.center_x and pt.center_y:
        trans_point.AddPoint(pt.center_x, pt.center_y)
    else:
        trans_point.AddPoint(pt.map_x, pt.map_y)
    trans_point.AssignSpatialReference(in_srs)
    trans_point.TransformTo(raster_srs)

    # Map coordinates -> fractional pixel coordinates -> containing pixel.
    fx, fy = gdal.ApplyGeoTransform(gt_reverse, trans_point.GetX(), trans_point.GetY())
    px, py = floor(fx), floor(fy)

    if not _in_bounds(px, py, x_count, y_count):
        # Handles point being outside of tile due to truncation from TNM.
        print("Point outside of pulled tile, attempting to pull correct tile.")
        # Move extra pixel(s) further out on each retry to overcome truncated digits.
        step = 8 * added
        if px < 0:
            px -= step
        if py < 0:
            py -= step
        if px >= x_count:
            px += step
        if py >= y_count:
            py += step
        dx, dy = gdal.ApplyGeoTransform(gt_forward, px, py)
        if added >= 10:
            with lock:
                print("Could not find tile for {0} in resolution {1}".format(pt.name, data.type))
            return None, None, None, None
        return dx, dy, raster_srs, added + 1

    # Initial value below the GNIS point.
    gnis_value = cell_value(px, py, rb)
    if gnis_value == ndv:
        with lock:
            print("Found NDV for ", pt.name)
    elif not pt.original_z:
        pt.original_z = gnis_value
        pt.original_x, pt.original_y = to_latlon(gt_forward, px, py, raster_srs)

    # Queue the 8 cells around the start. An off-raster neighbour means the
    # adjoining tile is needed before climbing.
    new_px, new_py = adj_values(px, py, rb, neighbors, x_count, y_count)
    if new_px is None:
        # Climb towards higher cells. Records the summit on pt when found.
        new_px, new_py = check_neighbors(gnis_value, (px, py), neighbors, rb, pt, False,
                                         buffer, x_count, y_count)
    if new_px is not None:
        dx, dy = gdal.ApplyGeoTransform(gt_forward, new_px, new_py)
        return dx, dy, raster_srs, 0

    # NOTE(review): assumes unset corrected-summit coordinates are None
    # (pixel 0 is a valid coordinate). Confirm against the point class.
    if pt.corrected_summit_x is not None and pt.corrected_summit_y is not None:
        pt.map_corrected_summit_x, pt.map_corrected_summit_y = to_latlon(
            gt_forward, pt.corrected_summit_x, pt.corrected_summit_y, raster_srs)
    else:
        with lock:
            print("Failed to find summit for: ", pt.name)

    if pt.corrected_value != ndv:
        # Don't set the point info if it was a NDV
        pt.set_info(xmin, ymax, in_srs, raster_srs, data.type)
        return None, None, None, None

    # The summit value is no-data: mark these tiles bad, clear the data
    # locations and return the start point (presumably so the caller retries
    # with other data).
    # TODO: Look at pulling the next tile in the list
    data.bad_location = data.bad_location + tif_list
    data.location = []
    dx, dy = gdal.ApplyGeoTransform(gt_forward, px, py)
    return dx, dy, raster_srs, 0


def find_raster_summit(pt, buffer, in_srs, lock, added=0):
    """
    Find the highest raster cell near ``pt`` and record it as the corrected summit.

    The files in ``pt.get_most_recent_data().location`` are mosaicked into a
    VRT inside a private temporary directory. That directory is removed on
    return, so concurrent callers never share a mosaic. The point is located on
    the mosaic and a hill climb from that cell records ``corrected_summit_x``,
    ``corrected_summit_y`` and ``corrected_value`` on ``pt``.

    :param pt: point object providing ``get_most_recent_data()``, location
        attributes and ``set_info``
    :param buffer: radius, in cells, used to confirm a summit
    :param in_srs: spatial reference of the point's map coordinates
    :param lock: lock held while printing status messages
    :param added: number of earlier attempts to step onto an adjoining tile
    :return: ``(None, None, None, None)`` when no further data is needed (or
        after 10 failed attempts to place the point). Otherwise
        ``(x, y, raster_srs, added)``: map coordinates in ``raster_srs`` for
        which more data is needed, plus the attempt count to pass back in.
    :raises Exception: if the point's data has no files
    :raises RuntimeError: if the mosaic VRT cannot be opened
    """
    data = pt.get_most_recent_data()
    if not data.location:
        raise Exception("No tif found for {}".format(data.name))

    work_dir = tempfile.mkdtemp(prefix='summit_')
    try:
        tif_list = _probe_tiles(pt, data, work_dir)

        vrt_name = os.path.join(work_dir, 'mosaic.vrt')
        # Absolute paths so the VRT resolves its sources regardless of where it lives.
        build_vrt_from_dir([os.path.abspath(f) for f in tif_list], vrt_name,
                           args="-r bilinear -resolution lowest -overwrite")
        vrt_dataset = gdal.Open(vrt_name)
        if vrt_dataset is None:
            raise RuntimeError("Could not open mosaic built from {}".format(tif_list))
        try:
            return _search_mosaic(pt, data, vrt_dataset, tif_list, buffer, in_srs, lock, added)
        finally:
            # Release the dataset before its directory is deleted.
            del vrt_dataset
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


def to_latlon(gt_forward, px, py, raster_srs):
    """
    Convert a pixel position to WGS84 (EPSG:4326) geographic coordinates.

    The pixel's upper-left corner (not its centre) is converted.

    :param gt_forward: pixel -> map geotransform of the raster
    :param px: pixel column
    :param py: pixel row
    :param raster_srs: spatial reference of the raster's map coordinates
    :return: ``(longitude, latitude)``. Despite the name, the order is (x, y),
        matching how callers assign the result to ``*_x, *_y``.
    """
    target_srs = osr.SpatialReference()
    target_srs.ImportFromEPSG(4326)  # WGS84
    # Pin (lon, lat) axis order rather than relying on GDAL's default mapping
    # strategy. GDAL >= 3 defaults to authority order (lat, lon) for EPSG:4326.
    target_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    coord_trans = osr.CoordinateTransformation(raster_srs, target_srs)

    x, y = gdal.ApplyGeoTransform(gt_forward, px, py)
    point = ogr.Geometry(ogr.wkbPoint)
    point.AddPoint(x, y)
    point.Transform(coord_trans)
    return point.GetX(), point.GetY()