import subprocess
import xml.etree.ElementTree as ET
from osgeo import gdal, osr, ogr
import struct
from Priority_Queue import PQueue
from math import floor
import os
from pathlib import Path
from PIL import Image
import rasterio as rio
from rasterio.merge import merge
from rasterio.plot import show

# File path to distribution folder in my environment
Distribution = "/tmp/.x2go-areed-comeaux/media/disk/_cygdrive_X_/13/TIFF/current/"


def cell_value(px, py, raster_band):
    """Read a single pixel from a raster band.

    Args:
        px: x coordinate of point in pixels (column)
        py: y coordinate of point in pixels (row)
        raster_band: the GDAL raster band to read from

    Returns:
        The value underneath the given pixel, read as a 32-bit float.
    """
    struct_val = raster_band.ReadRaster(px, py, 1, 1, buf_type=gdal.GDT_Float32)
    return struct.unpack('f', struct_val)[0]


def _in_bounds(px, py, x_count, y_count):
    """Return True if pixel (px, py) lies inside an x_count-by-y_count raster."""
    # Valid pixel indices are 0..x_count-1 and 0..y_count-1.
    return 0 <= px < x_count and 0 <= py < y_count


# Finds the values, along with their coords, of the pixels around (px, py)
# TODO: trigger pulling next tile if a neighbor is outside of raster
def adj_values(px, py, raster_band, neighbors, x_count, y_count, buffer=1):
    """Push the pixels surrounding (px, py) onto the ``neighbors`` queue.

    For every offset pair (x, y) with 1 <= x, y <= buffer, the pixels at
    top-left, top-center, top-right, center-left, center-right, bottom-left,
    bottom-center and bottom-right (offsets (+-x, +-y), (0, +-y), (+-x, 0))
    are visited in that order. Note that for buffer > 1 the center row and
    column pixels are pushed once per pair, i.e. more than once.

    Args:
        px: x coordinate of point in pixels
        py: y coordinate of point in pixels
        raster_band: the raster band
        neighbors: a priority queue; receives ``push((x, y), value)`` calls
        x_count: raster width in pixels
        y_count: raster height in pixels
        buffer: how many pixels out from (px, py) to look

    Returns:
        ``(None, None)`` when every visited neighbor was inside the raster.
        Otherwise, the pixel coordinates of the first neighbor that falls
        outside the raster. Neighbors visited before it have already been
        pushed, and the rest are skipped.
    """
    for x in range(1, buffer + 1):
        for y in range(1, buffer + 1):
            offsets = (
                (-x, -y), (0, -y), (x, -y),  # top row
                (-x, 0), (x, 0),             # center row
                (-x, y), (0, y), (x, y),     # bottom row
            )
            for dx, dy in offsets:
                nx, ny = px + dx, py + dy
                if not _in_bounds(nx, ny, x_count, y_count):
                    # TODO: Pull neighbor tile and join with input tile
                    return nx, ny
                neighbors.push((nx, ny), cell_value(nx, ny, raster_band))
    return None, None


# Iterative search for higher points around a test value
def check_neighbors(test_value, point, neighbors, rb, input_point, summit_found, buffer, x_count, y_count):
    """Climb from ``point`` toward higher raster values until a summit is found.

    Entries are popped from ``neighbors``; whenever one is higher than the
    current ``test_value`` the search moves there and queues its neighbors.
    When the popped value is not higher, ``check_neighbors_within_buffer``
    decides whether the current point is the summit. NOTE(review): this
    presumably relies on ``PQueue.pop_point`` returning the highest-valued
    entry - confirm against Priority_Queue.

    Args:
        test_value: the current highest value
        point: (x, y) pixel coordinates of the current candidate
        neighbors: priority queue holding the values and coordinates of
            neighboring pixels
        rb: raster band
        input_point: the Point the summit is being searched for; receives
            ``corrected_summit_x``, ``corrected_summit_y`` and
            ``corrected_value`` when a summit is found
        summit_found: if already True, nothing is searched
        buffer: pixel radius used for the final summit check
        x_count: raster width in pixels
        y_count: raster height in pixels

    Returns:
        ``(None, None)`` when the search finished inside the raster (the
        summit, if found, is stored on ``input_point``). Otherwise, the pixel
        coordinates of a neighbor outside the raster, so the caller can pull
        the adjacent tile.
    """
    if summit_found:
        return None, None

    while neighbors.length() != 0:
        popped_x, popped_y, popped_value = neighbors.pop_point()
        if not _in_bounds(popped_x, popped_y, x_count, y_count):
            print("outside of file")

        if popped_value > test_value:
            # Move the point to the better summit candidate and queue its neighbors
            off_x, off_y = adj_values(popped_x, popped_y, rb, neighbors, x_count, y_count)
            if off_x is not None or off_y is not None:
                return off_x, off_y
            test_value, point = popped_value, (popped_x, popped_y)
        elif popped_value <= test_value:
            # Not higher (NaN values fall through both branches and are skipped)
            (summit_found, possible_summit_x, possible_summit_y, pvalue,
             new_px, new_py) = check_neighbors_within_buffer(
                point[0], point[1], buffer, rb, neighbors, test_value, x_count, y_count)
            if new_px is not None or new_py is not None:
                return new_px, new_py
            if summit_found:
                input_point.corrected_summit_x = possible_summit_x
                input_point.corrected_summit_y = possible_summit_y
                input_point.corrected_value = pvalue
                return None, None
            # A higher cell was found near the current point: continue from there
            test_value, point = pvalue, (possible_summit_x, possible_summit_y)

    return None, None


def check_neighbors_within_buffer(px, py, buffer, rb, neighbors, test_value, x_count, y_count):
    """Decide whether (px, py) is a summit by checking pixels within ``buffer``.

    Args:
        px: X coordinate of current summit point in pixels
        py: Y coordinate of current summit point in pixels
        buffer: pixel radius to search around (px, py)
        rb: Raster band
        neighbors: Priority queue with values and coordinates of the adjacent pixels
        test_value: Current highest elevation value
        x_count: raster width in pixels
        y_count: raster height in pixels

    Returns:
        A 6-tuple ``(summit_found, x, y, value, off_px, off_py)``:

        * ``(None, None, None, None, off_px, off_py)`` if a pixel within the
          buffer is outside the raster;
        * ``(False, x, y, value, None, None)`` if a higher cell (x, y) was popped;
        * ``(True, px, py, test_value, None, None)`` otherwise.
    """
    new_px, new_py = adj_values(px, py, rb, neighbors, x_count, y_count, buffer)
    if new_px is not None or new_py is not None:
        return None, None, None, None, new_px, new_py

    popped_x, popped_y, popped_value = neighbors.pop_point()
    if popped_value > test_value:
        return False, popped_x, popped_y, popped_value, None, None
    # Nothing higher (a NaN popped value also counts as "not higher")
    return True, px, py, test_value, None, None


def get_raster_info(raster):
    """Collect the band, transforms, extent and size information of a raster.

    Args:
        raster: the GDAL dataset (raster tile) to grab information from

    Returns:
        ``(rb, gt_reverse, gt_forward, xmin, ymax, cell_size_x, cell_size_y,
        raster_srs, x_count, y_count)``. Here ``rb`` is band 1, and
        ``gt_forward``/``gt_reverse`` map pixel -> map and map -> pixel
        coordinates. ``xmin``/``ymax`` are the geotransform origin.
        ``cell_size_y`` is the negated geotransform y step, so it is positive
        for north-up rasters. ``raster_srs`` is built from the dataset
        projection, and ``x_count``/``y_count`` are the raster width/height
        in pixels.
    """
    # Get the raster band
    rb = raster.GetRasterBand(1)

    # Raster width and height in pixels
    y_count = raster.RasterYSize
    x_count = raster.RasterXSize

    # Get the initial raster srs
    raster_prj = raster.GetProjection()
    raster_srs = osr.SpatialReference(wkt=raster_prj)

    # Coefficients for transforming between pixel/line raster space and projection coordinate space
    gt_forward = raster.GetGeoTransform()

    # Get info from raster for later use
    xmin, ymax, cell_size_x, cell_size_y = gt_forward[0], gt_forward[3], gt_forward[1], -gt_forward[5]

    # Invert that transform, so we can convert points from map to pixel
    gt_reverse = gdal.InvGeoTransform(gt_forward)

    return rb, gt_reverse, gt_forward, xmin, ymax, cell_size_x, cell_size_y, raster_srs, x_count, y_count


def build_vrt_from_dir(src_dir, dst_vrt, args=None):
    """Build a VRT mosaic from a list of raster datasets, then display it.

    Args:
        src_dir: list of raster datasets. It is not modified; the VRT is
            built from its entries in reverse order.
        dst_vrt: output VRT file path; must end with ``.vrt``. An existing
            file at this path is removed first, and missing parent
            directories are created.
        args: optional buildvrt options passed as ``options=`` to
            ``gdal.BuildVRT``

    Returns:
        None. Builds the VRT and shows it with ``rasterio.plot.show``.

    Raises:
        Exception: if ``dst_vrt`` does not end with ``.vrt``.
    """
    # gdal cache to 2gb
    gdal.SetCacheMax(2 ** 31)

    if not dst_vrt.endswith('.vrt'):
        raise Exception('Invalid destination dataset: must be VRT format')

    if os.path.isfile(dst_vrt):
        os.remove(dst_vrt)

    # dirname() is '' for a bare file name (current directory): nothing to create then.
    dst_dir = os.path.dirname(dst_vrt)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

    # Reverse a copy: reversing the caller's list in place flipped its order on every call.
    sources = list(reversed(src_dir))
    if args:
        vrt_ds = gdal.BuildVRT(dst_vrt, sources, options=args)
    else:
        vrt_ds = gdal.BuildVRT(dst_vrt, sources)
    # GDAL writes the VRT out when the dataset is closed; release it before reading the file back.
    vrt_ds = None

    with rio.open(dst_vrt) as raster:
        show(raster)
    return None


# if read fails it needs to be able to use exception as data file name, load up that tiff, then redo everything (raster
# info, cell size, etc.)
def _distribution_path(tiff_name):
    """Return the path of the tile ``tiff_name`` inside the Distribution folder.

    Only the last ``/``-separated component of ``tiff_name`` is used as the
    file name. The sub-directory is the last ``_``-separated part of that
    name with ``.tif`` removed. For example, ``a/b_c_d.tif`` maps to
    ``Distribution + "d/b_c_d.tif"``.
    """
    file_name = tiff_name.split('/')[-1]
    tile_dir = file_name.split('_')[-1].replace('.tif', '') + '/'
    return Distribution + tile_dir + file_name


def find_raster_summit(pt, buffer, in_srs, lock, added=0):
    """Find the summit (highest reachable raster cell) for ``pt``.

    The most recent data location of ``pt`` is opened first. If anything in
    the attempt raises, the exception text up to its first comma is treated
    as the failing tile's path, and that tile is retried once from the
    Distribution folder.

    Args:
        pt: current point to find summit for; type is Point
        buffer: pixel radius used by ``check_neighbors`` for the final
            summit check
        in_srs: input spatial reference of ``pt``'s coordinates
        lock: lock (used with ``with``) serializing console output between processes
        added: a counter of how many times the point has already been
            shifted to pull a neighboring tile; gives up once it reaches 10

    Returns:
        A 4-tuple ``(dx, dy, raster_srs, added)``.

        * When ``dx``/``dy`` are not None, they are map coordinates (in
          ``raster_srs``) of the next location to check in another tile, and
          ``added`` is the counter to pass back in.
        * ``(None, None, None, None)`` means the search finished (``pt`` has
          been updated), gave up, or both read attempts failed.
    """
    tries = 0
    tiff_name = None
    while tries <= 1:
        try:
            # Priority queue for storing neighboring cells
            neighbors = PQueue()
            # Bool to keep track of when a summit is found
            summit_found = False
            data = pt.get_most_recent_data()

            if tries:
                # The previous read failed: read the tile named in the error from Distribution
                location = _distribution_path(tiff_name)
            else:
                # First attempt: use the location stored on the data (the VRT)
                location = data.location

            vrt_dataset = gdal.Open(location)
            ndv = vrt_dataset.GetRasterBand(1).GetNoDataValue()
            (rb, gt_reverse, gt_forward, xmin, ymax, cell_size_x, cell_size_y,
             raster_srs, x_count, y_count) = get_raster_info(vrt_dataset)

            # set cell size of dataset
            data.cell_size_x = cell_size_x
            data.cell_size_y = cell_size_y

            # Traditional GIS axis order (x first). This modifies raster_srs itself,
            # which is also the SRS returned to the caller below.
            raster_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            trans_point = ogr.Geometry(ogr.wkbPoint)
            if pt.center_x and pt.center_y:
                trans_point.AddPoint(pt.center_x, pt.center_y)
            else:
                trans_point.AddPoint(pt.map_x, pt.map_y)
            trans_point.AssignSpatialReference(in_srs)
            # In order for this transform to work it was necessary to work around the ssl certificate. Originally
            # it was trying to download things from proj-data and running into issues with ssl certificate. To fix,
            # proj-data was installed in the environment.
            trans_point.TransformTo(raster_srs)

            # Map coordinates -> pixel coordinates of the containing cell
            frac_x, frac_y = gdal.ApplyGeoTransform(gt_reverse, trans_point.GetX(), trans_point.GetY())
            px = floor(frac_x)
            py = floor(frac_y)

            # Valid pixels are 0..x_count-1 and 0..y_count-1; anything else is off this tile.
            # Handles point being outside of tile due to truncation from TNM
            if px < 0 or py < 0 or px >= x_count or py >= y_count:
                print("Point outside of pulled tile, attempting to pull correct tile.")
                # move extra pixel(s) to try and overcome truncating digits
                if px < 0:
                    px -= 8 * added
                if py < 0:
                    py -= 8 * added
                if px >= x_count:
                    px += 8 * added
                if py >= y_count:
                    py += 8 * added
                if added >= 10:
                    with lock:
                        print("Could not find tile for {0} in resolution {1}".format(pt.name, data.type))
                    return None, None, None, None
                dx, dy = gdal.ApplyGeoTransform(gt_forward, px, py)
                return dx, dy, raster_srs, added + 1

            # Get the initial value below the GNIS point
            GNIS_value = cell_value(px, py, rb)
            if GNIS_value and GNIS_value == ndv:
                with lock:
                    print("Found NDV for ", pt.name)
            elif not pt.original_z:
                pt.original_z = GNIS_value
                pt.original_x, pt.original_y = to_latlon(gt_forward, px, py, raster_srs)

            # Get the initial neighbors; a non-None result is a pixel off this tile.
            # (Compare against None: pixel index 0 is a valid, falsy coordinate.)
            new_px, new_py = adj_values(px, py, rb, neighbors, x_count, y_count)
            if new_px is not None or new_py is not None:
                dx, dy = gdal.ApplyGeoTransform(gt_forward, new_px, new_py)
                return dx, dy, raster_srs, 0

            # Check for higher values
            new_px, new_py = check_neighbors(GNIS_value, (px, py), neighbors, rb, pt, summit_found,
                                             buffer, x_count, y_count)
            if new_px is not None or new_py is not None:
                dx, dy = gdal.ApplyGeoTransform(gt_forward, new_px, new_py)
                return dx, dy, raster_srs, 0

            # Get map coords of corrected summit.
            # NOTE(review): a corrected summit at pixel 0 is falsy here; switching to
            # "is not None" depends on Point's defaults - confirm before changing.
            if pt.corrected_summit_x and pt.corrected_summit_y:
                pt.map_corrected_summit_x, pt.map_corrected_summit_y = to_latlon(
                    gt_forward, pt.corrected_summit_x, pt.corrected_summit_y, raster_srs)
            else:
                with lock:
                    print("Failed to find summit for: ", pt.name)

            if pt.corrected_value != ndv:
                # Don't set the point info if it was a NDV
                pt.set_info(xmin, ymax, in_srs, raster_srs, data.type)
            else:
                # If current try to fix bad location bug doesn't work maybe put the chunk of code down here
                data.bad_location.append(data.name)
                data.location = []
                dx, dy = gdal.ApplyGeoTransform(gt_forward, px, py)
                return dx, dy, raster_srs, 0

            # TODO: Look at pulling the next tile in the list
            return None, None, None, None
        except Exception as e:
            print(e)
            tries += 1
            # Presumably the GDAL error text starts with the failing file path followed by a
            # comma - NOTE(review) confirm. str(e) cannot itself fail the way e.args[0].split could.
            tiff_name = str(e).split(',')[0]

    # Both the original location and the Distribution fallback failed: keep the 4-tuple contract.
    return None, None, None, None


def to_latlon(gt_forward, px, py, raster_srs):
    """Convert a pixel position to WGS84 (EPSG:4326) coordinates.

    Args:
        gt_forward: forward geotransform (pixel -> map) of the raster
        px: x coordinate in pixels
        py: y coordinate in pixels
        raster_srs: spatial reference of the raster's map coordinates

    Returns:
        ``(point.GetY(), point.GetX())`` of the point transformed to EPSG:4326.
        NOTE(review): which value is latitude depends on the EPSG:4326 axis
        mapping. GDAL >= 3 defaults to authority order (lat, lon), in which
        case this returns (lon, lat); callers store it as (x, y) - confirm.
    """
    targetSR = osr.SpatialReference()
    targetSR.ImportFromEPSG(4326)  # WGS84
    coordTrans = osr.CoordinateTransformation(raster_srs, targetSR)
    point = ogr.Geometry(ogr.wkbPoint)
    x, y = gdal.ApplyGeoTransform(gt_forward, px, py)
    point.AddPoint(x, y)
    point.Transform(coordTrans)
    return point.GetY(), point.GetX()