from pathlib import Path
from typing import Dict, Optional
import sys
from time import perf_counter
from datetime import timedelta
import shutil

import arcpy.sa
import arcpy

from base_classes import Tool
from utils import print_header, isolate_env, UseExtension

FIELD_AREA_SQM = "area_sqm"
FIELD_REVIEW = "Review"
DIVIDER = 80 * "-"

class WalledFill(Tool):
    """ArcGIS toolbox tool that walls a DEM along boundary lines and flags flow impacts."""

    def __init__(self):
        """Describe the tool to ArcGIS (the class name doubles as the tool name)."""
        self.category = "Geometry checks"
        self.label = "G3 Walled Fill"
        self.canRunInBackground = False
        self.description = (
            "Walls DEM at boundary lines and finds where flow is impeded or altered"
        )

    def getParameterInfo(self):
        """Return the tool parameters in the order they appear in the dialog."""

        def make(label, key, datatype, required=True, default=None, allowed=None):
            # Every parameter of this tool is an input; only the requirement,
            # filter list and default value differ between them.
            param = arcpy.Parameter(
                displayName=label,
                name=key,
                datatype=datatype,
                parameterType="Required" if required else "Optional",
                direction="Input",
            )
            if allowed is not None:
                param.filter.list = allowed
            if default is not None:
                param.value = default
            return param

        return [
            make(
                "Output geodatabase",
                "output_gdb",
                "DEWorkspace",
                allowed=["Local Database"],
            ),
            make(
                "Drainage area boundary lines for wall generation",
                "da_bnd_line",
                "DEFeatureClass",
                allowed=["Polyline"],
            ),
            make(
                "Flowline features for breaching walls",
                "flowline",
                "DEFeatureClass",
                allowed=["Polyline"],
            ),
            make("DEM", "dem_in", "DERasterDataset"),
            make("Wall height (m)", "wall_height", "GPDouble", default=5.0),
            make("Maximum depth to fill (m)", "max_depth", "GPDouble", default=50.0),
            make("Minimum fill depth (m)", "min_depth", "GPDouble", default=0.0),
            make("Minimum polygon area (sq. m)", "min_area", "GPDouble", default=0.0),
            make("Cluster distance (m)", "cluster_dist", "GPDouble", default=30.0),
            make(
                "Minimum distance (m) from wall breaches",
                "dist_to_breach",
                "GPDouble",
                default=30.0,
            ),
            make(
                "Maximum distance (m) from drainage area boundary lines",
                "distance_from_boundary",
                "GPDouble",
                required=False,
            ),
        ]

    def execute(self, parameters, messages):
        """Run the walled-fill analysis with the values entered in the tool dialog."""
        print_header(self)
        by_name: Dict[str, arcpy.Parameter] = {param.name: param for param in parameters}

        def as_path(key):
            return Path(by_name[key].valueAsText)

        def as_float(key):
            return float(by_name[key].value)

        # The boundary distance is optional: pass None through when it is unset.
        boundary_dist = by_name["distance_from_boundary"].value
        main(
            as_path("output_gdb"),
            as_path("da_bnd_line"),
            as_path("flowline"),
            as_path("dem_in"),
            as_float("wall_height"),
            as_float("min_depth"),
            as_float("max_depth"),
            as_float("min_area"),
            as_float("cluster_dist"),
            as_float("dist_to_breach"),
            None if boundary_dist is None else float(boundary_dist),
        )


def cluster_polys(
    in_poly_path: Path, out_poly_path: Path, temp_gdb_path: Path, distance: float
):
    """Merge polygons lying within ``distance`` of each other into one feature each.

    Polygons are grouped with Aggregate Polygons. The resulting group id
    (``OUTPUT_FID``) is joined back onto ``in_poly_path``, so the input feature
    class is modified in place. The input is then dissolved on that id into
    ``out_poly_path``.

    Args:
        in_poly_path: Polygon feature class to cluster; gains an OUTPUT_FID field.
        out_poly_path: Output feature class with one (multipart) feature per cluster.
        temp_gdb_path: Geodatabase for the intermediate aggregate outputs. They are
            deleted afterwards, even when a step fails.
        distance: Aggregation distance, passed straight to Aggregate Polygons.
            NOTE(review): a unitless value is presumably taken in the input's
            linear units; confirm the data uses a metre-based projection.
    """
    in_poly = str(in_poly_path)
    temp_agg_polys = str(temp_gdb_path.joinpath(f"{in_poly_path.stem}_agg_polys"))
    temp_agg_table = str(temp_gdb_path.joinpath(f"{in_poly_path.stem}_agg_table"))
    # Join on the dataset's actual object-id field instead of assuming "OBJECTID".
    oid_field = arcpy.Describe(in_poly).OIDFieldName
    try:
        # Create temp aggregated polygons plus the input->output id table
        arcpy.AggregatePolygons_cartography(
            in_poly, temp_agg_polys, distance, out_table=temp_agg_table
        )
        # Join aggregated poly FID to the original polys
        arcpy.JoinField_management(
            in_poly, oid_field, temp_agg_table, "INPUT_FID", "OUTPUT_FID"
        )
        # Dissolve on the aggregated FID
        arcpy.PairwiseDissolve_analysis(
            in_poly, str(out_poly_path), dissolve_field="OUTPUT_FID"
        )
    finally:
        # Always remove temp data. Check existence first, so a failed earlier
        # step does not trigger a second error that hides the original one.
        for temp in (temp_agg_polys, temp_agg_table):
            if arcpy.Exists(temp):
                arcpy.Delete_management(temp)


def select_by_size_and_distance(
    lyr: str,
    bnd_lines: str,
    breach_pts: str,
    min_area: float,
    dist_to_breach: float,
    distance_from_boundary: Optional[float],
):
    """Leave only the candidate polygons worth reviewing selected in ``lyr``.

    Every feature is selected first. The selection is then narrowed by each
    enabled filter, in this order:

    * ``min_area`` (skipped when falsy, e.g. 0): keep polygons whose
      FIELD_AREA_SQM value is at least ``min_area``.
    * ``dist_to_breach`` (skipped when falsy): drop polygons within that
      distance of any feature in ``breach_pts``.
    * ``distance_from_boundary`` (skipped when None): keep polygons within that
      distance of ``bnd_lines``.

    Only the layer's selection is changed; no data is written.
    """
    select_attr = arcpy.SelectLayerByAttribute_management
    select_loc = arcpy.SelectLayerByLocation_management

    # Clearing and then switching the (now empty) selection selects every feature.
    select_attr(lyr, "CLEAR_SELECTION")
    select_attr(lyr, "SWITCH_SELECTION")

    if min_area:
        size_clause = f"{FIELD_AREA_SQM} >= {min_area}"
        select_attr(lyr, "SUBSET_SELECTION", size_clause)

    if dist_to_breach:
        # INVERT turns "within the distance" into "farther than the distance".
        select_loc(
            lyr,
            "WITHIN_A_DISTANCE",
            breach_pts,
            dist_to_breach,
            "SUBSET_SELECTION",
            "INVERT",
        )

    if distance_from_boundary is not None:
        select_loc(
            lyr,
            "WITHIN_A_DISTANCE",
            bnd_lines,
            distance_from_boundary,
            "SUBSET_SELECTION",
        )


def _copy_selected(lyr: str, out_fc: str):
    """Copy the features selected in layer ``lyr`` to the feature class ``out_fc``.

    ArcGIS geoprocessing tools treat a layer with an empty selection as having no
    selection at all, and process every feature. Without this guard, filters that
    match nothing would copy the whole unfiltered layer. Instead, an empty feature
    class with the layer's schema is written.
    """
    if arcpy.Describe(lyr).FIDSet:
        arcpy.CopyFeatures_management(lyr, out_fc)
        return
    # Nothing selected: clear the selection and export with a where clause that
    # matches no rows. This keeps the schema but copies zero features.
    arcpy.SelectLayerByAttribute_management(lyr, "CLEAR_SELECTION")
    arcpy.Select_analysis(lyr, out_fc, "1 = 0")


def _add_review_field(fc: str):
    """Add the text FIELD_REVIEW field to ``fc`` unless it is already present.

    Outputs derived from user inputs (projected flowlines, intersection points)
    may already carry a field with that name, e.g. from an earlier review. Such a
    field is left as is rather than added again. Names are compared
    case-insensitively.
    """
    existing = {field.name.lower() for field in arcpy.ListFields(fc)}
    if FIELD_REVIEW.lower() not in existing:
        arcpy.AddField_management(fc, FIELD_REVIEW, "TEXT")


@isolate_env
def main(
    out_gdb: Path,
    da_line_path: Path,
    flowline_path: Path,
    dem_path: Path,
    wall_height: float,
    min_depth: float,
    max_depth: float,
    min_area: float,
    cluster_dist: float,
    dist_to_breach: float,
    distance_from_boundary: Optional[float],
):
    """Run the G3 walled-fill check.

    Walls of ``wall_height`` are burned into the filled DEM along the drainage-area
    boundary lines, and breached wherever a flowline crosses them. The walled
    surface is then used to find:

    * areas that fill deeper than ``min_depth`` once walled;
    * cells whose D8 flow direction differs from the unwalled filled DEM;
    * sinks in the walled surface.

    Feature outputs are written to ``out_gdb``, which is created if missing:
    G3_breach_lines, G3_breach_points, G3_fill_areas, G3_flow_direction_changed
    and G3_sinks. Each gets a text review field. Rasters (G3_walls.tif,
    G3_walled_DEM.tif, G3_d8_filled_angles.tif plus a .lyrx) are written to the
    folder containing ``out_gdb``. Intermediates go to ``<folder>/temp`` and
    ``<folder>/temp/temp.gdb``.

    Args:
        out_gdb: Output file geodatabase path.
        da_line_path: Drainage-area boundary polylines used to build the walls.
        flowline_path: Flowlines; those crossing the boundaries breach the walls.
        dem_path: Input DEM; it also sets the snap raster, cell size and extent.
        wall_height: Height added to the DEM along the (unbreached) walls.
        min_depth: Fill depth that must be exceeded to count as a fill area.
        max_depth: z-limit passed to both Fill operations.
        min_area: Minimum clustered polygon area in square metres (0 disables).
        cluster_dist: Distance within which polygons are clustered together.
        dist_to_breach: Polygons this close to a breach point are dropped (0 disables).
        distance_from_boundary: If given, keep only polygons within this distance
            of the boundary lines.
    """
    start_all = perf_counter()
    # Set environment variables
    folder_path = out_gdb.parent
    arcpy.env.workspace = str(folder_path)
    if not arcpy.Exists(str(out_gdb)):
        arcpy.CreateFileGDB_management(str(folder_path), out_gdb.name)
    temp_folder = folder_path.joinpath("temp")
    temp_folder.mkdir(exist_ok=True)
    temp_gdb = temp_folder.joinpath("temp.gdb")
    if not arcpy.Exists(str(temp_gdb)):
        arcpy.CreateFileGDB_management(str(temp_folder), temp_gdb.name)
    arcpy.env.scratchWorkspace = str(temp_folder)
    arcpy.env.snapRaster = str(dem_path)
    arcpy.env.cellSize = str(dem_path)
    arcpy.env.extent = str(dem_path)
    arcpy.env.overwriteOutput = True
    arcpy.env.parallelProcessingFactor = "100%"

    # Build paths
    da_line = str(da_line_path)
    flowline = str(flowline_path)
    out_breaching_lines = str(out_gdb.joinpath("G3_breach_lines"))
    out_breach_points = str(out_gdb.joinpath("G3_breach_points"))
    out_fill_polys = str(out_gdb.joinpath("G3_fill_areas"))
    out_fd_polys = str(out_gdb.joinpath("G3_flow_direction_changed"))
    out_sinks = str(out_gdb.joinpath("G3_sinks"))

    out_walls = str(folder_path.joinpath("G3_walls.tif"))
    out_fd = str(folder_path.joinpath("G3_d8_filled_angles.tif"))
    out_walled = str(folder_path.joinpath("G3_walled_DEM.tif"))
    out_d8_lyr = str(folder_path.joinpath("G3_d8_filled_angles.tif.lyrx"))

    temp_proj = str(temp_gdb.joinpath("da_bnd_proj"))
    temp_sinks = str(temp_gdb.joinpath("sinks"))
    temp_fill_polys = str(temp_gdb.joinpath("fill_polys"))
    temp_fd_change = str(temp_gdb.joinpath("fd_change_polys"))
    temp_fill_agg = str(temp_gdb.joinpath("fill_agg"))
    temp_fd_agg = str(temp_gdb.joinpath("fd_agg"))
    temp_fd_erase = str(temp_gdb.joinpath("fd_erase"))
    # Give the rasterized lines explicit paths in the temp folder, instead of
    # letting the tool auto-generate names and locations
    temp_ras_fl = str(temp_folder.joinpath("breach_lines.tif"))
    temp_ras_walls = str(temp_folder.joinpath("walls.tif"))
    # Named layers, so their selection state can be inspected with Describe
    lyr_fill = "g3_fill_lyr"
    lyr_fd = "g3_fd_lyr"

    ras_dem = arcpy.Raster(str(dem_path))

    # Project DA boundaries and flowlines to match DEM
    start_wall_prep = perf_counter()
    arcpy.AddMessage(f"{DIVIDER}\nStarted creating breached walls")
    arcpy.AddMessage("Projecting boundary lines and flowlines to match DEM")

    arcpy.Project_management(
        da_line, temp_proj, out_coor_system=ras_dem.spatialReference
    )

    # keep only flowlines that intersect da boundary lines
    # NOTE(review): if no flowline intersects, the selection is empty and
    # Project will copy every flowline.
    lyr_fl = arcpy.MakeFeatureLayer_management(flowline)
    arcpy.SelectLayerByLocation_management(lyr_fl, "INTERSECT", da_line)
    arcpy.Project_management(
        lyr_fl, out_breaching_lines, out_coor_system=ras_dem.spatialReference
    )
    # Create points where boundary lines intersect flowlines. Both inputs are the
    # projected copies, so the breach points share the DEM's coordinate system.
    arcpy.PairwiseIntersect_analysis(
        [temp_proj, out_breaching_lines], out_breach_points, output_type="POINT"
    )

    # Calculate height field to wall height
    arcpy.CalculateField_management(
        temp_proj, "height", str(wall_height), "PYTHON3", field_type="DOUBLE"
    )

    # Rasterize both line datasets
    arcpy.AddMessage("Rasterizing lines")
    arcpy.CalculateField_management(
        out_breaching_lines, "is_line", "1", "PYTHON3", field_type="SHORT"
    )
    arcpy.PolylineToRaster_conversion(
        out_breaching_lines, "is_line", temp_ras_fl, cell_assignment="MAXIMUM_LENGTH"
    )
    ras_fl = arcpy.Raster(temp_ras_fl)

    arcpy.PolylineToRaster_conversion(
        temp_proj, "height", temp_ras_walls, cell_assignment="MAXIMUM_LENGTH"
    )
    ras_walls = arcpy.Raster(temp_ras_walls)

    # Walling
    with UseExtension("Spatial"):
        # Breach walls with lines by setting overlap to nodata:
        # flowline cells -> 1, everything else -> 0, then SetNull on the 1s.
        arcpy.AddMessage("Breaching DA boundary line walls")
        ras_fl = arcpy.sa.Con(arcpy.sa.IsNull(ras_fl), 0, 1)
        breached_walls = arcpy.sa.SetNull(ras_fl, ras_walls)
        breached_walls.save(out_walls)

        arcpy.AddMessage(
            f"Breached walls created in {timedelta(seconds=perf_counter()-start_wall_prep)}"
        )

        arcpy.AddMessage(f"{DIVIDER}\nStarted creating filled and walled DEMs")
        start_dem_prep = perf_counter()

        # Fill DEM first
        arcpy.AddMessage("Filling DEM")
        ras_first_fill = arcpy.sa.Fill(ras_dem, max_depth)

        # Add DA boundary line to filled DEM to wall DEM (non-wall cells add 0)
        arcpy.AddMessage("Walling filled DEM")
        ras_walls_zfill = arcpy.sa.Con(
            arcpy.sa.IsNull(breached_walls), 0, breached_walls
        )
        ras_walled = ras_first_fill + ras_walls_zfill
        ras_walled.save(out_walled)

    arcpy.AddMessage(
        f"Filled and walled DEMs created in {timedelta(seconds=perf_counter()-start_dem_prep)}"
    )

    # --------------------------------------------------------------------------------
    # Fill analysis
    # --------------------------------------------------------------------------------
    with UseExtension("Spatial"):
        start_fill = perf_counter()
        arcpy.AddMessage(f"{DIVIDER}\nStarted fill depth analysis")

        # Fill DEM
        arcpy.AddMessage("Filling walled DEM again")
        ras_walled_fill = arcpy.sa.Fill(ras_walled, max_depth)

        # Subtract walled from walled and filled to get fill depth; keep a
        # constant 1 only where the depth exceeds min_depth
        arcpy.AddMessage("Getting fill depth")
        ras_depth = ras_walled_fill - ras_walled
        ras_depth_regions = arcpy.sa.SetNull(ras_depth <= min_depth, 1)

    # Convert fill areas to polygons
    arcpy.AddMessage("Converting filled areas to polygons")
    arcpy.RasterToPolygon_conversion(
        ras_depth_regions, temp_fill_polys, simplify="NO_SIMPLIFY"
    )
    # Cluster polygons within specified distance
    arcpy.AddMessage("Clustering filled areas")
    cluster_polys(Path(temp_fill_polys), Path(temp_fill_agg), temp_gdb, cluster_dist)
    arcpy.CalculateGeometryAttributes_management(
        temp_fill_agg, [[FIELD_AREA_SQM, "AREA"]], area_unit="SQUARE_METERS"
    )

    # Filter fill areas by size, distance to DA boundary lines, and distance to breach points
    arcpy.MakeFeatureLayer_management(temp_fill_agg, lyr_fill)
    arcpy.AddMessage("Filtering fill polys by size and distance to DA boundary line")
    select_by_size_and_distance(
        lyr_fill,
        temp_proj,
        out_breach_points,
        min_area,
        dist_to_breach,
        distance_from_boundary,
    )
    _copy_selected(lyr_fill, out_fill_polys)
    arcpy.AddMessage(
        f"Fill analysis completed in {timedelta(seconds=perf_counter()-start_fill)}"
    )

    # --------------------------------------------------------------------------------
    # Flow direction analysis
    # --------------------------------------------------------------------------------
    with UseExtension("Spatial"):
        start_fd = perf_counter()
        arcpy.AddMessage(f"{DIVIDER}\nStarted flow direction change analysis")
        arcpy.AddMessage("D8 Flow direction for filled DEM")
        ras_d8_filled = arcpy.sa.FlowDirection(ras_first_fill)
        # D8 codes -> compass bearing in degrees (1 = east = 90, 64 = north = 0, ...)
        remap = arcpy.sa.RemapValue(
            [
                [1, 90],
                [2, 135],
                [4, 180],
                [8, 225],
                [16, 270],
                [32, 315],
                [64, 0],
                [128, 45],
            ]
        )
        ras_d8_reclass = arcpy.sa.Reclassify(ras_d8_filled, "VALUE", remap)
        ras_d8_reclass.save(out_fd)

        # copy layer file
        lyr_file = str(Path(__file__).parent.joinpath("Layers", "d8_arrows.tif.lyrx"))
        shutil.copyfile(lyr_file, out_d8_lyr)

        # Create flow direction raster of walled DEM
        arcpy.AddMessage("D8 flow direction for walled DEM")
        ras_d8_walled = arcpy.sa.FlowDirection(ras_walled)

        # Find where flow directions have been changed by the walls
        # (wall cells themselves are excluded)
        arcpy.AddMessage("Finding where flow direction has changed")
        ras_fd_changed = arcpy.sa.SetNull(
            (ras_walls_zfill != 0) | (ras_d8_filled == ras_d8_walled), 1
        )

        # Find sinks created by walls
        arcpy.AddMessage("Finding sinks")
        ras_sinks = arcpy.sa.Sink(ras_d8_walled)

    # Polygonize sinks
    arcpy.AddMessage("Converting sinks to polygons")
    arcpy.RasterToPolygon_conversion(ras_sinks, temp_sinks, simplify="NO_SIMPLIFY")
    # Cluster sinks within specified distance
    cluster_polys(Path(temp_sinks), Path(out_sinks), temp_gdb, cluster_dist)
    arcpy.CalculateGeometryAttributes_management(
        out_sinks, [[FIELD_AREA_SQM, "AREA"]], area_unit="SQUARE_METERS"
    )

    # Convert flow direction changes to polygons
    arcpy.AddMessage("Converting flow direction changes to polygons")
    arcpy.RasterToPolygon_conversion(
        ras_fd_changed, temp_fd_change, simplify="NO_SIMPLIFY"
    )

    # Erase flow direction polys that overlap fill polys
    arcpy.AddMessage("Removing overlap between flow direction change and fill polygons")
    arcpy.PairwiseErase_analysis(temp_fd_change, out_fill_polys, temp_fd_erase)

    # Cluster within specified distance
    arcpy.AddMessage("Clustering flow direction change polys")
    cluster_polys(Path(temp_fd_erase), Path(temp_fd_agg), temp_gdb, cluster_dist)
    arcpy.CalculateGeometryAttributes_management(
        temp_fd_agg, [[FIELD_AREA_SQM, "AREA"]], area_unit="SQUARE_METERS"
    )

    # Filter by size, distance to DA boundary lines, and distance to breach points
    arcpy.MakeFeatureLayer_management(temp_fd_agg, lyr_fd)
    arcpy.AddMessage(
        "Filtering flow change areas by size and distance to DA boundary line"
    )
    select_by_size_and_distance(
        lyr_fd,
        temp_proj,
        out_breach_points,
        min_area,
        dist_to_breach,
        distance_from_boundary,
    )
    _copy_selected(lyr_fd, out_fd_polys)
    arcpy.AddMessage(
        f"Flow direction analysis completed in {timedelta(seconds=perf_counter()-start_fd)}"
    )
    # --------------------------------------------------------------------------------

    # Add review field
    arcpy.AddMessage("Adding review fields")
    for output in (
        out_fd_polys,
        out_breach_points,
        out_breaching_lines,
        out_fill_polys,
        out_sinks,
    ):
        _add_review_field(output)

    arcpy.AddMessage(
        f"Walled fill analysis completed in {timedelta(seconds=perf_counter()-start_all)}"
    )


if __name__ == "__main__":
    # sys.argv[0] is the script path itself; the tool arguments start at index 1.
    # Expected order: output_gdb da_bnd_line flowline dem wall_height min_depth
    # max_depth min_area cluster_dist dist_to_breach [distance_from_boundary]
    args = sys.argv[1:]
    if len(args) < 10:
        sys.exit(
            f"usage: {Path(sys.argv[0]).name} <output_gdb> <da_bnd_line> <flowline> "
            "<dem> <wall_height> <min_depth> <max_depth> <min_area> <cluster_dist> "
            "<dist_to_breach> [distance_from_boundary]"
        )
    gdb = Path(args[0])
    da_bnd_line = Path(args[1])
    flowline = Path(args[2])
    dem = Path(args[3])
    wall_height = float(args[4])
    min_depth = float(args[5])
    max_depth = float(args[6])
    min_area = float(args[7])
    cluster_dist = float(args[8])
    dist_to_breach = float(args[9])
    try:
        distance_from_boundary = float(args[10])
    except (IndexError, ValueError):
        # Optional argument: omitted, or given as a non-numeric placeholder
        distance_from_boundary = None
    main(
        gdb,
        da_bnd_line,
        flowline,
        dem,
        wall_height,
        min_depth,
        max_depth,
        min_area,
        cluster_dist,
        dist_to_breach,
        distance_from_boundary,
    )
