"""
 Tool Name:  Download3DEP
 Description: Downloads 3DEP data by HUC code using the National Map API or the Alaska.gov website.
 Author: Patrick Longley (plongley@usgs.gov)
 Created: 08/11/2020
 Language: Written in python3 (arcpro). Modified to also work in python2 (arcmap).
 History:
"""

# IMPORTS
import urllib
import os
import shutil
import arcpy
from zipfile import ZipFile
import json
import sys
import re
from contextlib import closing
import wbd_f
import wbd_params
import download_wbd
from WhiteboxTools_win_amd64.WBT.whitebox_tools import WhiteboxTools

# Python 2 (ArcMap) / Python 3 (ArcGIS Pro) compatibility: major interpreter version.
python_version = sys.version_info[0]
# 3DEP dataset name that is downloaded from the Alaska.gov portal instead of
# the National Map API.
AK_IFSAR = 'Alaska IFSAR 5 meter DEM'
# Field names added to the per-tile extent polygons.
YEAR, TILE = 'year', 'tile'
class Download3DEP(object):
    """
    Downloads 3DEP data by HUC code using the National Map API and the Alaska.gov website.

    Tool parameters (in order):
        huc (str): HUC code(s); more than one can be entered, separated by ';'.
        out_folder (str): Output folder where data will be saved.
        spatial_reference (ESRI spatial reference object): Spatial reference the DEM is
            projected to. Optional; when empty the clipped mosaic keeps its native reference.
        dataset_3dep (str): 3DEP dataset that will be downloaded.
        buffer_dist (str): Buffer distance around the HUC polygon (e.g. '1000 Meters').
        contour_spacing (str): Contour interval. Optional; contours are skipped when empty.

    Outputs:
        returns: path of the DEM (.tif) for the last HUC processed.
        output parameter: None

    Notes:
        Tile cells equal to -10000 are removed (treated as no data) before mosaicking.
    """

    def __init__(self):
        """
        Initialize the tool label, description and category.
        """
        self.label       = "Download 3DEP data"
        self.description = "This tool downloads 3DEP data using the national map api. " + \
                           "The extent is specified using a HUC code and a WBD feature dataset."
        # True when run from the .pyt toolbox (params are arcpy Parameter objects);
        # set to False to pass plain values (see the __main__ section).
        self.callfrom_pyt = True
        self.category = 'Download data'

    def getParameterInfo(self):
        """
        Define the parameters for use in ArcMap/ArcGIS Pro.

        The HUC parameter definition differs between python 2 (ArcMap)
        and python 3 (ArcGIS Pro).
        """
        if python_version == 3:
            huc = wbd_params.huc_py3
        elif python_version == 2:
            huc = wbd_params.huc_py2
        params = [
            huc,
            wbd_params.out_folder,
            wbd_params.spatial_reference,
            wbd_params.dataset_3dep,
            wbd_params.buffer_dist,
            wbd_params.contour_spacing,
        ]
        return params

    def updateMessages(self, params):
        """
        Modify the messages created by internal validation for each tool
        parameter. This method is called after internal validation.

        Flags the HUC parameter as an error when any entered code fails
        ``wbd_f.check_hucre``.
        """
        if params[0].value and params[0].altered and params[0].value != '':
            huclist = params[0].valueAsText.replace(" ", "").split(';')
            if not wbd_f.check_hucre(huclist):
                params[0].setErrorMessage('Invalid HUC code(s).')

    def updateParameters(self, params):
        """
        Modify the values and properties of parameters before internal
        validation is performed. This method is called whenever a parameter
        has been changed.

        Restricts the 3DEP dataset choices based on the 2-digit HUC regions entered.
        """
        if not params[0].value or params[0].value == '':
            params[3].filter.list = wbd_params.dataset_3dep.filter.list
        elif params[0].altered:
            huclist = params[0].valueAsText.replace(" ", "").split(';')
            huc2s = [x[:2] for x in huclist]
            params[3].filter.list = wbd_f.modify_3depchoices(huc2s, params[3].filter.list)

    def get_wbd(self):
        """
        Checks if the correct WBD data exists and downloads it if needed.

        Sets ``self.wbd_gdb`` and ``self.wbd_fc``. The geodatabase is looked up
        in the parent folder when ``out_folder`` ends with '_prep'.
        """
        huc2 = self.huc[:2]
        # where to find geodatabase
        if self.out_folder.endswith('_prep'):
            wbd_folder = os.path.dirname(self.out_folder)
        else:
            wbd_folder = self.out_folder
        self.wbd_gdb = os.path.join(wbd_folder, 'WBD_{}_HU2.gdb'.format(huc2))
        self.wbd_fc = os.path.join(self.wbd_gdb, 'WBD', 'WBDHU' + self.hudigit)
        if not arcpy.Exists(self.wbd_fc):
            downloadwbd = download_wbd.DownloadWBD()
            # NOTE(review): downloads into out_folder while the lookup above uses
            # wbd_folder; presumably DownloadWBD applies the same '_prep' rule - confirm.
            params = (huc2, self.out_folder, None)
            downloadwbd.callfrom_pyt = False
            downloadwbd.execute(params, None)

    def create_wbd_fl(self):
        """Create a feature layer ('huc_fl') holding the WBD polygon for ``self.huc``."""
        hufield = 'huc' + self.hudigit
        where = """ {} = {} """.format(
                        arcpy.AddFieldDelimiters(self.wbd_fc, hufield),
                        "'" + self.huc + "'"
                        )
        with arcpy.EnvManager(overwriteOutput=True):
            self.wbd_fl = arcpy.MakeFeatureLayer_management(self.wbd_fc, "huc_fl", where_clause=where)

    def create_buffer(self):
        """
        Create the area of interest used to query and clip the 3DEP data.

        With a buffer distance, the HUC polygon is buffered into a feature class
        in the WBD geodatabase (reused if it already exists). Without one, the HUC
        feature layer itself is used. Also sets ``self.buffer_extent`` to the
        extent formatted for the download service:
        - a 'xmin,ymin,xmax,ymax' string for the National Map API, or
        - polygon JSON for the Alaska IFSAR portal.

        Returns:
            The buffer feature class path or the HUC feature layer, or None when
            the HUC code is not found in the WBD feature class.
        """
        self.create_wbd_fl()
        if int(arcpy.GetCount_management(self.wbd_fl)[0]) == 0:
            return None
        if self.buffer_dist:
            # Feature class names may only contain letters, digits and '_'
            # (e.g. '0.5 Miles' -> '0_5Miles'; '1000 Meters' -> '1000Meters').
            dist_name = re.sub(r'\W', '_', ''.join(self.buffer_dist.split()))
            buffer = os.path.join(self.wbd_gdb, ''.join(['hu', self.huc, '_buf_', dist_name]))
            if not arcpy.Exists(buffer):
                arcpy.Buffer_analysis(self.wbd_fl,
                                      buffer,
                                      self.buffer_dist)
            extent = arcpy.Describe(buffer).extent
        else:
            buffer = self.wbd_fl
            # Union every matching shape so the extent covers all features,
            # not just the last one returned by the cursor.
            shape = None
            with arcpy.da.SearchCursor(self.wbd_fl, ['SHAPE@']) as cursor:
                for row in cursor:
                    shape = row[0] if shape is None else shape.union(row[0])
            extent = shape.extent
        if self.dataset != AK_IFSAR:
            self.buffer_extent = ','.join(
                str(v) for v in (extent.XMin, extent.YMin, extent.XMax, extent.YMax))
        else:
            self.buffer_extent = extent.polygon.JSON
        return buffer

    @staticmethod
    def _tile_stem(url):
        """Return the file name at the end of *url* without its extension."""
        return os.path.splitext(url.split('/')[-1])[0]

    @staticmethod
    def _format_years(years):
        """Return the unique *years* sorted ascending and concatenated ('' if none)."""
        return ''.join(sorted(set(years)))

    def downloadurls_nationalmap(self):
        """
        Query the National Map API for the GeoTIFF tiles of ``self.dataset``
        intersecting ``self.buffer_extent``.

        Returns:
            list of (download url, metadata url) tuples.
        """
        import urllib.parse
        import urllib.request
        url = r'https://tnmaccess.nationalmap.gov/api/v1/products?'
        url_data = {
            'prodFormats': 'GeoTIFF',
            'prodExtents': '1 x 1 degree',
            'bbox': self.buffer_extent,
            'datasets': self.dataset,
        }
        url = ''.join([url, urllib.parse.urlencode(url_data)])
        req = urllib.request.Request(url)
        with closing(urllib.request.urlopen(req)) as response:
            json_obj = json.loads(response.read().decode('utf-8'))
        # TODO What if JSON format changes?
        return [(item['downloadURL'], item['metaUrl']) for item in json_obj['items']]

    def download_nationalmap(self):
        """
        Download every tile (and its HTML metadata) returned by the National Map API.

        Each metadata file is saved as '<tile name without extension>.html',
        which is the name ``get_years`` looks up.
        """
        import urllib.request
        for tile_url, metadata_url in self.downloadurls_nationalmap():
            tile = os.path.join(self.dem_folder, tile_url.split('/')[-1])
            metadata = os.path.join(self.dem_folder, self._tile_stem(tile_url) + '.html')
            urllib.request.urlretrieve(tile_url, tile)
            urllib.request.urlretrieve(metadata_url, metadata)

    def downloadurls_ifsar(self):
        """
        Build the Alaska data portal download url for IFSAR DTM data.

        The rings of ``self.buffer_extent`` (Esri polygon JSON) are sent as a
        GeoJSON Polygon, serialized as JSON text in the 'geojson' query parameter.
        """
        import urllib.parse
        url = r'https://elevation.alaska.gov/download?'
        rings = json.loads(self.buffer_extent)['rings']
        ifsarid_dtm = 152
        geojson = {
            "type": "Polygon",
            "coordinates": rings
        }
        url_data = {"geojson": json.dumps(geojson, separators=(',', ':')), "ids": ifsarid_dtm}
        return ''.join([url, urllib.parse.urlencode(url_data)])

    def download_ifsar(self):
        """
        Download the IFSAR zip from the Alaska portal and extract it into ``self.dem_folder``.

        Any zip files found inside the extracted content are extracted as well.
        """
        import urllib.request
        url = self.downloadurls_ifsar()
        zip_file = os.path.join(self.dem_folder, self.huc + '_tiles.zip')
        urllib.request.urlretrieve(url, zip_file)
        with ZipFile(zip_file, 'r') as zip_obj:
            zip_obj.extractall(self.dem_folder)
        os.remove(zip_file)
        # Collect the nested archives first so files extracted below are not
        # re-visited. Non-zip files are skipped: opening them with ZipFile
        # would raise BadZipFile.
        inner_zips = [os.path.join(dirpath, f)
                      for dirpath, _, files in os.walk(self.dem_folder)
                      for f in files if f.lower().endswith('.zip')]
        for inner_zip in inner_zips:
            with ZipFile(inner_zip, 'r') as zip_obj:
                zip_obj.extractall(self.dem_folder)

    def get_dates(self, metadata_file):
        """
        Return the acquisition year(s) found in a tile's HTML metadata file.

        Start and end years are parsed with dataset-specific regexes. Unique years
        are concatenated in ascending order (e.g. '2010', or '20102012' when the
        start and end years differ). Returns '' when neither date is found.

        Raises:
            IOError/OSError: if the metadata file cannot be opened.
        """
        import io
        if self.dataset == AK_IFSAR:
            STARTDATE_REGEX = r'<dt><em>Beginning_Date:</em>\s+([0-9]{4})[0-9]{4}</dt>'
            ENDDATE_REGEX = r'<dt><em>Ending_Date:</em>\s+([0-9]{4})[0-9]{4}</dt>'
        else:
            STARTDATE_REGEX = r'<dt>Start Date</dt>\s*<dd>([0-9]{4})-[0-9]{2}-[0-9]{2}</dd>'
            ENDDATE_REGEX = r'<dt>End Date</dt>\s*<dd>([0-9]{4})-[0-9]{2}-[0-9]{2}</dd>'
        # Decode explicitly so the platform default codec (e.g. cp1252 on
        # Windows) cannot raise UnicodeDecodeError on the HTML.
        with io.open(metadata_file, 'r', encoding='utf-8', errors='replace') as f:
            text = f.read()
        matches = (re.search(STARTDATE_REGEX, text), re.search(ENDDATE_REGEX, text))
        return self._format_years([m.group(1) for m in matches if m])

    def create_extentpolygon(self, raster, year):
        """
        Create a polygon of the raster's data domain, tagged with tile name and year.

        The polygon is written to ``self.dem_folder`` under the tile's base name.
        """
        tile_name = os.path.splitext(os.path.basename(raster))[0]
        polygon_fpath = os.path.join(self.dem_folder, tile_name)
        with arcpy.EnvManager(overwriteOutput=True):
            extent = arcpy.ddd.RasterDomain(raster, polygon_fpath, "POLYGON")
        arcpy.AddField_management(extent, field_name=YEAR, field_type='TEXT')
        arcpy.AddField_management(extent, field_name=TILE, field_type='TEXT')
        arcpy.CalculateField_management(extent, TILE, "'{}'".format(tile_name), 'PYTHON')
        arcpy.CalculateField_management(extent, YEAR, "'{}'".format(year), 'PYTHON')
        return extent

    def get_years(self):
        """
        Find the year of every .tif in the current workspace and create its extent polygon.

        Returns:
            tiles: list of (raster name, year string) tuples ('' when unknown).
            extents: list of the per-tile extent polygon results.
        """
        tiles = []
        extents = []
        for raster in arcpy.ListRasters('*', 'TIF') or []:
            metadata = os.path.join(self.dem_folder, os.path.splitext(raster)[0] + '.html')
            try:
                year = self.get_dates(metadata)
            except (IOError, OSError):
                # No readable metadata: the year is unknown.
                year = ''
            extents.append(self.create_extentpolygon(raster, year))
            tiles.append((raster, year))
        return tiles, extents

    def merge_extentfcs(self, extents):
        """Merge the per-tile extent polygons into 'extent_<huc>' and delete the inputs."""
        with arcpy.EnvManager(overwriteOutput=True):
            arcpy.Merge_management(extents, 'extent_' + self.huc)
        for fc in extents:
            arcpy.Delete_management(fc)

    def prep_tiles(self, tiles):
        """
        Order tiles for mosaicking and remove the -10000 no data values.

        ``tiles`` is sorted in place by year string: '' (no year) first, then
        oldest to newest.

        Returns:
            list of Con rasters (cells equal to -10000 set to no data).
        """
        tiles.sort(key=lambda x: x[1])
        return [arcpy.sa.Con(arcpy.Raster(t[0]) != -10000, t[0]) for t in tiles]

    def delete_files(self, tiles):
        """
        Best-effort cleanup of intermediate files in ``self.dem_folder``.

        Deletes the downloaded tiles. Then removes every file or folder whose name
        does not start with 'dem' or 'extent' and does not end with '.html'.
        Failures are reported as warnings rather than raised.
        """
        for tile in tiles:
            try:
                arcpy.Delete_management(tile[0])
            except Exception as err:  # best-effort cleanup: report, don't abort
                arcpy.AddWarning('Could not delete {}: {}'.format(tile[0], err))
        for f in os.listdir(self.dem_folder):
            if not f.startswith('dem') and not f.endswith('.html') and not f.startswith('extent'):
                fpath = os.path.join(self.dem_folder, f)
                if os.path.isdir(fpath):
                    shutil.rmtree(fpath)
                elif os.path.isfile(fpath):
                    try:
                        os.remove(fpath)
                    except OSError as err:
                        arcpy.AddWarning('Could not delete {}: {}'.format(fpath, err))

    def mosaic(self, tiles):
        """
        Mosaic tiles together into 'mosaic.tif' in ``self.dem_folder``.

        The tiles should already have no data values removed. They should be ordered
        with no-year tiles first, then oldest to newest. The 'LAST' mosaic method
        keeps the last tile's value where tiles overlap, so newer data overwrites
        older data.

        Args:
            tiles (list): list of tiles that have the no data value removed.
                They are deleted after mosaicking.

        Returns:
            mosaic (result object): Mosaic of the input tiles.
        """
        mosaic = arcpy.MosaicToNewRaster_management(
            tiles,
            self.dem_folder,
            'mosaic.tif',
            number_of_bands=1,
            pixel_type='32_BIT_FLOAT',
            mosaic_method='LAST',
        )
        for tile in tiles:
            arcpy.Delete_management(tile)
        return mosaic

    def create_contours(self, dem):
        """
        Create contours at ``self.contour_dist`` spacing, unless they already exist.

        Contours are an optional product; a failure is reported as a warning.
        """
        self.contours_fc = os.path.join(self.dem_folder, 'contours{}m_{}.shp'.format(self.contour_dist, self.huc))
        if not arcpy.Exists(self.contours_fc):
            try:
                arcpy.Contour_3d(dem,
                                 self.contours_fc,
                                 self.contour_dist)
            except arcpy.ExecuteError:
                arcpy.AddWarning('Contours not created for {}: {}'.format(
                    self.huc, arcpy.GetMessages(2)))

    def _build_dem(self, buffer, whitebox_tools):
        """
        Turn the downloaded tiles into the final DEM, hillshade and contours.

        Must be called with the arcpy workspace set to ``self.dem_folder``.
        """
        tiles, extents = self.get_years()
        if not tiles:
            arcpy.AddWarning('No 3DEP tiles found for {}. DEM not created.'.format(self.huc))
            return
        self.merge_extentfcs(extents)
        tiles_prep = self.prep_tiles(tiles)
        mosaic = self.mosaic(tiles_prep)
        if self.sr:
            clip = arcpy.Clip_management(
                mosaic[0],
                "",
                'clip.tif',
                in_template_dataset=buffer,
                clipping_geometry='ClippingGeometry'
            )
            dem = arcpy.ProjectRaster_management(
                clip[0],
                self.dem,
                out_coor_system=self.sr
            )
        else:
            dem = arcpy.Clip_management(
                mosaic[0],
                "",
                self.dem,
                in_template_dataset=buffer,
                clipping_geometry='ClippingGeometry')
        self.delete_files(tiles)
        whitebox_tools.multidirectional_hillshade(self.dem, self.hillshade)
        if self.contour_dist:
            self.contour_dist = str(self.contour_dist)
            self.create_contours(dem)

    def execute(self, params, messages):
        """
        Loop through the HUCs and download/process 3DEP data for each one.

        For each HUC, the method:
        - makes sure the WBD data is available and builds the area of interest;
        - downloads the tiles;
        - mosaics them (newest on top) and clips (optionally projects) them to the area;
        - creates a hillshade and, optionally, contours.

        HUCs whose DEM folder already exists, or that are not found in the WBD,
        are skipped with a warning.

        Args:
            params: arcpy Parameter objects (when ``self.callfrom_pyt``) or plain
                values, in the order returned by getParameterInfo.
            messages: unused.

        Returns:
            str: DEM path for the last HUC processed.
        """
        whitebox_tools = WhiteboxTools()
        if self.callfrom_pyt:
            values = [params[i].valueAsText for i in range(6)]
        else:
            values = [params[i] for i in range(6)]
        (self.hucs, self.out_folder, self.sr, self.dataset,
         self.buffer_dist, self.contour_dist) = values
        self.hucs = self.hucs.replace(" ", "").split(';')

        for huc in self.hucs:
            self.huc = huc
            self.hudigit = str(len(self.huc))
            self.get_wbd()
            self.dem_folder = os.path.join(self.out_folder, 'DEM_' + self.huc)
            self.dem = os.path.join(self.dem_folder, 'dem_' + self.huc + '.tif')
            self.hillshade = os.path.join(self.dem_folder, 'hillshade_' + self.huc + '.tif')
            if os.path.exists(self.dem_folder):
                arcpy.AddWarning('{} already exists, data not downloaded.'.format(self.dem_folder))
                continue
            buffer = self.create_buffer()
            if buffer is None:
                # Skip this HUC but keep processing the remaining ones.
                arcpy.AddWarning('No results found for {}. Data NOT downloaded.'.format(self.huc))
                continue
            # Create the folder only once the HUC is known to exist. Otherwise an
            # unknown HUC would leave an empty folder that makes later runs skip it.
            os.mkdir(self.dem_folder)
            if self.dataset == AK_IFSAR:
                self.download_ifsar()
            else:
                self.download_nationalmap()
            with arcpy.EnvManager(workspace=self.dem_folder):
                self._build_dem(buffer, whitebox_tools)
        return self.dem


if __name__ == '__main__':
    # Standalone execution (outside the ArcGIS toolbox): parameters are given
    # as plain values rather than arcpy Parameter objects.
    download_3dep = Download3DEP()
    download_3dep.callfrom_pyt = False
    params = (
        "1908030501",                                     # HUC code(s), ';'-separated
        r"C:\Users\plongley\Desktop\tooltest_041221",     # output folder
        '4269',                                           # output spatial reference
        "National Elevation Dataset (NED) 1 arc-second",  # 3DEP dataset
        '1000 Meters',                                    # buffer distance
        10,                                               # contour spacing
    )
    download_3dep.execute(params, None)
