import os
import typing
import arcpy
import pathlib
from archydro.streamsegmentation import StreamSegmentation
import wbd_c
import wbd_f

# initialize archydro
# NOTE(review): this runs at import time; in this module's visible code the
# instance is only used by the __main__ block (streamsegmentation.execute).
streamsegmentation = StreamSegmentation()
# NOTE(review): presumably tells Arc Hydro it is driven from a plain script
# rather than from a Python toolbox (.pyt) -- confirm against archydro.
streamsegmentation.bCallFromPYT = False
# NOTE(review): not referenced in this module's visible code; the functions
# annotate paths as "os.PathLike[typing.Any]" instead.
PathLike = typing.Union[str, pathlib.Path]

# Square meters per acre; get_mincells divides a cell area in m^2 by this
# to express it in acres.
ACRE_TOM2 = 4046.86
# Output dataset names. All are written to the output folder except
# BATCH_POINTS, which is written to the output geodatabase.
BIG_STREAMS = "Str_big.tif"
FAC_INT = "Fac_int.tif"
FAC_MAX = "Fac_max.tif"
FAC_RANGE = "Fac_range.tif"
FAC_SETNULL = "Fac_setnull.tif"
STRLNK = wbd_c.STRLNK + ".tif"
FAC_POINTS = "Fac_points.shp"
STRLNK_POINTS = "StrLnk_points.shp"
BATCH_POINTS = "proposed_batchpoints"
FAC_RANGESETNULL = "Fac_range_setnull.tif"
INTERSECTION_POLYGONS = "intersection_polygons.shp"
AGGREGATED_POLYGONS = "aggregated_polygons.shp"

# Field names. cleanup_batchpoints renames GRID_CODE -> FAC_FIELD,
# EXTRACT_FIELD -> STRLNK_FIELD and JOIN_FID -> POLYGON_ID on the batch
# points; GRIDCODE is the value field of the intersection polygons that
# create_intersection_polygons filters on.
GRIDCODE = "gridcode"
GRID_CODE = "grid_code"
FAC_FIELD = "fac"
EXTRACT_FIELD = "RASTERVALU"
STRLNK_FIELD = "strlnk"
JOIN_FID = "JOIN_FID"
POLYGON_ID = "polygon_id"


def get_mincells(
    min_watershedsize_ac: float, ras: "os.PathLike[typing.Any]"
) -> typing.Tuple[float, float]:
    """Convert the minimum watershed size from acres to a number of raster cells.

    Args:
        min_watershedsize_ac (float): Minimum watershed size in acres.
        ras (os.PathLike[typing.Any]): Raster used to get the cell size. Its
            linear units must be meters.

    Returns:
        typing.Tuple[float, float]: ``(cell_len, n_cells)`` -- the longer of the
        two cell edges (meters) and the minimum watershed size expressed as a
        (possibly fractional) number of cells.
    """
    x_len = float(arcpy.management.GetRasterProperties(ras, "CELLSIZEX").getOutput(0))
    y_len = float(arcpy.management.GetRasterProperties(ras, "CELLSIZEY").getOutput(0))
    # Longer cell edge; the two only differ for non-square cells.
    cell_len = max(x_len, y_len)
    # Cell area in m^2 converted to acres (ACRE_TOM2 = square meters per acre).
    cellsize_ac = x_len * y_len / ACRE_TOM2
    return cell_len, min_watershedsize_ac / cellsize_ac


def find_bigstreams(
    fac: "os.PathLike[typing.Any]",
    streams: "os.PathLike[typing.Any]",
    n_cells: float,
    out_folder: "typing.Optional[os.PathLike[typing.Any]]" = None,
) -> "os.PathLike[typing.Any]":
    """Create a stream raster limited to cells draining at least ``n_cells`` cells.

    Cells whose flow accumulation is below ``n_cells`` become NoData; all other
    cells take the value of ``streams``. The result is saved as BIG_STREAMS in
    ``out_folder``.

    Args:
        fac (os.PathLike[typing.Any]): Flow accumulation raster.
        streams (os.PathLike[typing.Any]): Stream raster.
        n_cells (float): Minimum number of cells to define a watershed.
        out_folder (os.PathLike[typing.Any], optional): Output folder. When
            omitted, a module-level ``out_folder`` is used (the behavior the
            ``__main__`` block relies on).

    Returns:
        os.PathLike[typing.Any]: The saved stream raster returned by SetNull.

    Raises:
        ValueError: If no output folder is given and none is defined at module level.
    """
    if out_folder is None:
        # Backward compatibility: this function used to read the module-level
        # ``out_folder`` set by the ``__main__`` block instead of a parameter.
        out_folder = globals().get("out_folder")
    if out_folder is None:
        raise ValueError(
            "find_bigstreams: out_folder must be passed or defined at module level"
        )
    big_streams = arcpy.ia.SetNull(fac, streams, f"VALUE < {n_cells}")
    big_streams.save(os.path.join(out_folder, BIG_STREAMS))
    return big_streams


def prep_fac(
    fac: "os.PathLike[typing.Any]",
    streams: "os.PathLike[typing.Any]",
    out_folder: "os.PathLike[typing.Any]",
) -> "os.PathLike[typing.Any]":
    """Convert flow accumulation to integers and keep it only on stream cells.

    The integer raster is written as FAC_INT in ``out_folder``. Every cell where
    ``streams`` is not 1 is then set to NoData, and the masked raster is saved
    as FAC_SETNULL in ``out_folder``.

    Args:
        fac (os.PathLike[typing.Any]): Flow accumulation raster.
        streams (os.PathLike[typing.Any]): Stream raster.
        out_folder (os.PathLike[typing.Any]): Output folder.

    Returns:
        os.PathLike[typing.Any]: The masked integer flow accumulation raster.
    """
    int_path = os.path.join(out_folder, FAC_INT)
    masked_path = os.path.join(out_folder, FAC_SETNULL)

    fac_as_int = arcpy.ddd.Int(fac, int_path)
    # Null wherever the stream raster is not 1; otherwise take the integer fac.
    masked = arcpy.ia.SetNull(streams, fac_as_int, "VALUE <> 1")
    masked.save(masked_path)
    return masked


def create_range_ras(
    fac: "os.PathLike[typing.Any]",
    streams: "os.PathLike[typing.Any]",
    out_folder: "os.PathLike[typing.Any]",
) -> "os.PathLike[typing.Any]":
    """Compute the 3x3 focal range of flow accumulation, kept only on stream cells.

    The unmasked range raster is saved as FAC_RANGE and the stream-masked one as
    FAC_RANGESETNULL, both in ``out_folder``.

    Args:
        fac (os.PathLike[typing.Any]): Flow accumulation raster.
        streams (os.PathLike[typing.Any]): Stream raster.
        out_folder (os.PathLike[typing.Any]): Output folder.

    Returns:
        os.PathLike[typing.Any]: The stream-masked flow accumulation range raster.
    """
    range_path = os.path.join(out_folder, FAC_RANGE)
    masked_path = os.path.join(out_folder, FAC_RANGESETNULL)

    # Max minus min of fac within each cell's 3x3 window ("DATA" ignores NoData).
    neighbourhood_range = arcpy.sa.FocalStatistics(
        fac, "Rectangle 3 3 CELL", "RANGE", "DATA"
    )
    neighbourhood_range.save(range_path)

    stream_range = arcpy.ia.SetNull(streams, neighbourhood_range, "VALUE <> 1")
    stream_range.save(masked_path)
    return stream_range


def create_intersection_polygons(
    facrange_ras: "os.PathLike[typing.Any]",
    n_cells: float,
    cell_len: float,
    out_folder: "os.PathLike[typing.Any]",
) -> "os.PathLike[typing.Any]":
    """Create polygons at stream intersections.

    The flow accumulation range raster is converted to polygons
    (INTERSECTION_POLYGONS). Only polygons whose ``gridcode`` is at least
    ``n_cells`` are kept. Those polygons are then aggregated using one cell
    length as the aggregation distance (AGGREGATED_POLYGONS). Both outputs are
    written to ``out_folder``.

    Args:
        facrange_ras (os.PathLike[typing.Any]): Flow accumulation range raster.
        n_cells (float): Minimum watershed size in cells (may be fractional).
        cell_len (float): Length of a cell edge in meters.
        out_folder (os.PathLike[typing.Any]): Output folder.

    Returns:
        os.PathLike[typing.Any]: Output of the AggregatePolygons tool (the
        aggregated intersection polygons feature class).
    """
    intersection_polygons = arcpy.RasterToPolygon_conversion(
        facrange_ras,
        os.path.join(out_folder, INTERSECTION_POLYGONS),
        "NO_SIMPLIFY",
        "Value",
        "SINGLE_OUTER_PART",
        None,
    )
    gridcode_field = arcpy.AddFieldDelimiters(intersection_polygons, GRIDCODE)
    where = f"{gridcode_field} >= {n_cells}"
    # Allow replacing a "polygons_fl" layer left over from an earlier call in the
    # same session, as every other layer creation in this module does.
    with arcpy.EnvManager(overwriteOutput=True):
        polygons_fl = arcpy.MakeFeatureLayer_management(
            intersection_polygons, "polygons_fl", where_clause=where
        )
    return arcpy.AggregatePolygons_cartography(
        polygons_fl,
        os.path.join(out_folder, AGGREGATED_POLYGONS),
        f"{cell_len} Meters",
        "0 SquareMeters",
        "0 SquareMeters",
        "NON_ORTHOGONAL",
    )


def create_batch_points(
    fac_ras: "os.PathLike[typing.Any]",
    strlink_ras: "os.PathLike[typing.Any]",
    intersection_polygons: "os.PathLike[typing.Any]",
    out_folder: "os.PathLike[typing.Any]",
    out_gdb: "os.PathLike[typing.Any]",
) -> "os.PathLike[typing.Any]":
    """Create candidate batch points at stream intersections.

    Flow accumulation cells are converted to points and the stream link value is
    sampled at each point. Points that are not within an intersection polygon
    are deleted, and the remaining points are spatially joined (one-to-many) to
    the intersection polygons into BATCH_POINTS in ``out_gdb``.

    Args:
        fac_ras (os.PathLike[typing.Any]): Flow accumulation raster.
        strlink_ras (os.PathLike[typing.Any]): Stream link raster.
        intersection_polygons (os.PathLike[typing.Any]): Intersection polygons.
        out_folder (os.PathLike[typing.Any]): Output folder for intermediate points.
        out_gdb (os.PathLike[typing.Any]): Output geodatabase.

    Returns:
        os.PathLike[typing.Any]: Output of the SpatialJoin tool (the batch points
        feature class).
    """
    # Flow accumulation raster -> points (FAC_POINTS in out_folder).
    fac_points = arcpy.RasterToPoint_conversion(
        fac_ras, os.path.join(out_folder, FAC_POINTS)
    )
    # Sample the stream link raster at each point (STRLNK_POINTS in out_folder).
    strlnk_points = arcpy.sa.ExtractValuesToPoints(
        fac_points, strlink_ras, os.path.join(out_folder, STRLNK_POINTS)
    ).getOutput(0)
    # overwriteOutput lets the layers be re-created if they already exist.
    with arcpy.EnvManager(overwriteOutput=True):
        points_fl = arcpy.MakeFeatureLayer_management(strlnk_points, "points_fl")
        intersection_polygons_fl = arcpy.MakeFeatureLayer_management(
            intersection_polygons, "intersection_polygons_fl"
        )
    # Select the points that are NOT within any intersection polygon...
    arcpy.SelectLayerByLocation_management(
        points_fl,
        "WITHIN",
        intersection_polygons_fl,
        invert_spatial_relationship="INVERT",
    )
    # ...and delete them; the tool's output is reused as the input of the join layer.
    strlnk_points = arcpy.DeleteFeatures_management(points_fl).getOutput(0)
    with arcpy.EnvManager(overwriteOutput=True):
        strlnk_fl = arcpy.MakeFeatureLayer_management(strlnk_points, "strlnk_fl")
    # One-to-many join attaches the intersection polygon id (JOIN_FID) to each point.
    return arcpy.SpatialJoin_analysis(
        strlnk_fl,
        intersection_polygons_fl,
        os.path.join(out_gdb, BATCH_POINTS),
        join_operation="JOIN_ONE_TO_MANY",
    )


def _select_oids_to_delete(
    rows: typing.Iterable[typing.Sequence[typing.Any]],
) -> typing.Set[typing.Any]:
    """Return the OIDs of batch points that should be deleted.

    Rows are ``(polygon_id, stream_link, fac, oid)``.

    - Within each intersection polygon, only the point with the largest ``fac``
      per stream link is kept. On a tie, the first point seen wins.
    - Polygons that do not have exactly three distinct stream links are dropped
      entirely.
    - For the remaining polygons, the point with the overall largest ``fac`` is
      dropped as well. That is the link below the confluence of the two
      tributaries, so one point per tributary is left.
    """
    # polygon_id -> {stream_link: (fac, oid)} of the current best point per link
    best: typing.Dict[typing.Any, typing.Dict[typing.Any, typing.Tuple[typing.Any, typing.Any]]] = {}
    to_delete: typing.Set[typing.Any] = set()
    for polygon_id, str_link, fac, oid in rows:
        links = best.setdefault(polygon_id, {})
        current = links.get(str_link)
        if current is None:
            links[str_link] = (fac, oid)
        elif current[0] < fac:
            # Keep the biggest fac (i.e. the point closest to the link's outlet).
            to_delete.add(current[1])
            links[str_link] = (fac, oid)
        else:
            to_delete.add(oid)
    for links in best.values():
        if len(links) != 3:
            # Not a simple two-tributary confluence: discard every point.
            to_delete.update(oid for _, oid in links.values())
        else:
            # Drop the downstream link (largest fac) below the confluence.
            to_delete.add(max(links.values())[1])
    return to_delete


def cleanup_batchpoints(batchpoints_fc: "os.PathLike[typing.Any]"):
    """Delete bad batch points from a batch point feature class, in place.

    The spatial-join fields are renamed to ``fac``, ``strlnk`` and
    ``polygon_id``, and every point chosen by ``_select_oids_to_delete`` is
    deleted.

    Args:
        batchpoints_fc (os.PathLike[typing.Any]): File path to the batch points feature class.

    Returns:
        None
    """
    with arcpy.EnvManager(overwriteOutput=True):
        batchpoints_fl = arcpy.MakeFeatureLayer_management(
            batchpoints_fc, "batchpoints_fl"
        )
    # TODO rename aliases also
    renames = (
        (GRID_CODE, FAC_FIELD),
        (EXTRACT_FIELD, STRLNK_FIELD),
        (JOIN_FID, POLYGON_ID),
    )
    for old_name, new_name in renames:
        arcpy.AlterField_management(
            batchpoints_fl, field=old_name, new_field_name=new_name
        )
    fields = [POLYGON_ID, STRLNK_FIELD, FAC_FIELD, "OID@"]
    with arcpy.da.SearchCursor(batchpoints_fl, fields) as cursor:
        to_delete = _select_oids_to_delete(cursor)
    # to_delete is a set, so each membership test is O(1) and this pass stays linear.
    with arcpy.da.UpdateCursor(batchpoints_fl, ["OID@"]) as cursor:
        for row in cursor:
            if row[0] in to_delete:
                cursor.deleteRow()


if __name__ == "__main__":
    # output file paths
    # NOTE: out_folder must stay a module-level name; find_bigstreams reads it
    # as a global.
    out_folder = r"D:\OneDrive - DOI\WBD_Collaboration\AK\Work\IfSAR_Updates\hu19060502\19060502_prep\DEM_19060502\test"
    out_gdb = r"D:\OneDrive - DOI\WBD_Collaboration\AK\Work\IfSAR_Updates\hu19060502\hu19060502.gdb"
    # Input rasters: flow accumulation, flow direction and streams.
    fac_ras = r"D:\OneDrive - DOI\WBD_Collaboration\AK\Work\IfSAR_Updates\hu19060502\19060502_prep\DEM_19060502\Fac.tif"
    fdr_ras = r"D:\OneDrive - DOI\WBD_Collaboration\AK\Work\IfSAR_Updates\hu19060502\19060502_prep\DEM_19060502\Fdr_archydro.tif"
    streams_ras = r"D:\OneDrive - DOI\WBD_Collaboration\AK\Work\IfSAR_Updates\hu19060502\19060502_prep\DEM_19060502\Str.tif"
    # Set to an existing stream link raster to skip stream segmentation below.
    strlnk_ras = None
    # Minimum watershed size in acres.
    area_acres = 7000

    cell_len, min_cells = get_mincells(area_acres, fac_ras)
    big_streams = find_bigstreams(fac_ras, streams_ras, min_cells)
    # From here on fac_ras is the integer flow accumulation masked to big streams.
    fac_ras = prep_fac(fac_ras, big_streams, out_folder)
    if not strlnk_ras:
        # Arc Hydro stream segmentation writing STRLNK into out_folder.
        # NOTE(review): assumes execute returns a 2-tuple whose second item is
        # the stream link raster -- confirm against archydro.
        _, strlnk_ras = streamsegmentation.execute(
            [big_streams, fdr_ras, os.path.join(out_folder, STRLNK),], None
        )
    facrange_ras = create_range_ras(fac_ras, big_streams, out_folder)
    intersection_polygons = create_intersection_polygons(
        facrange_ras, min_cells, cell_len, out_folder
    )
    batch_points = create_batch_points(
        fac_ras, strlnk_ras, intersection_polygons, out_folder, out_gdb
    )
    cleanup_batchpoints(batch_points)
