import sys
import json
import math
import psycopg2
from h3 import h3
from cgi import parse_qs, escape
from osgeo import ogr, osr


# PostgreSQL connection settings handed to psycopg2.connect() by the query
# helpers in this module.
# NOTE(review): the credentials are hard-coded in source -- consider loading
# them from the environment or a config file instead.
PSQL_HOST = 'localhost'
PSQL_PORT = 5432
PSQL_USER = 'imw'
PSQL_PASSWORD = 'imw'
PSQL_DATABASE = 'imw'
# Not referenced anywhere in the visible part of this file.
CHUNK = 1
# Variance term read by gaussian(), which applies it to distances that have
# first been divided by max_distance.
VARIANCE = 0.2


def pull_data(hexids, table, level):
    """Fetch the decoded ``data`` column for rows whose hex cell is in *hexids*.

    Args:
        hexids: iterable of hex cell ids matched against column ``hex<level>``.
        table: name of the table to query. It is interpolated into the SQL
            text as an identifier, so it must come from trusted code, never
            from request input.
        level: resolution suffix of the ``hex<level>`` column; coerced to
            ``int`` so nothing but a number can reach the SQL text.

    Returns:
        A list with one ``json.loads``-decoded object per matching row. An
        empty *hexids* yields ``[]`` without touching the database (the
        equivalent ``IN ()`` would be a SQL syntax error).

    Raises:
        ValueError / TypeError: if *level* is not convertible to ``int``.
    """
    hexids = tuple(str(hexid) for hexid in hexids)
    if not hexids:
        return []

    # Only the identifier parts are formatted in; the id list itself is bound
    # as a query parameter (psycopg2 adapts a tuple to a parenthesised list).
    query = "SELECT data FROM {table} WHERE hex{level} IN %s;".format(
        table=table, level=int(level))

    conn = psycopg2.connect(host=PSQL_HOST,
                            port=PSQL_PORT,
                            user=PSQL_USER,
                            password=PSQL_PASSWORD,
                            database=PSQL_DATABASE)
    try:
        cursor = conn.cursor()
        cursor.execute(query, (hexids,))
        return [json.loads(row[0]) for row in cursor]
    finally:
        # Release the connection even when the query or the decoding fails.
        conn.close()


def get_suburbs(wkt):
    """Return the suburb intersecting *wkt* followed by its neighbours.

    The first query fetches one suburb whose geometry intersects *wkt*
    (interpreted as SRID 4326). The second fetches every suburb intersecting
    that suburb's own outline (``geom_raw``) buffered by 0.001 in the units
    of SRID 4326.

    Args:
        wkt: WKT text of the geometry to look up.

    Returns:
        A list of ``{'name': ..., 'geom_raw': ...}`` dicts, the containing
        suburb first and each name appearing only once; ``[]`` when no suburb
        intersects *wkt*.
    """
    conn = psycopg2.connect(host=PSQL_HOST,
                            port=PSQL_PORT,
                            user=PSQL_USER,
                            password=PSQL_PASSWORD,
                            database=PSQL_DATABASE)
    try:
        cursor = conn.cursor()
        # WKT is bound as a parameter instead of being pasted into the SQL.
        cursor.execute(
            "select data from suburbs where st_intersects("
            "geom, ST_SetSRID(ST_GeomFromText(%s),4326));",
            (wkt,))
        row = cursor.fetchone()
        if not row:
            return []

        home = json.loads(row[0])
        seen = {home['name']}
        suburbs = [{'name': home['name'], 'geom_raw': home['geom_raw']}]

        cursor.execute(
            "select data from suburbs where st_intersects("
            "geom, ST_Buffer(ST_SetSRID(ST_GeomFromText(%s), 4326), 0.001));",
            (home['geom_raw'],))
        for row in cursor:
            suburb = json.loads(row[0])  # decode each row once
            if suburb['name'] in seen:
                continue
            seen.add(suburb['name'])
            suburbs.append({'name': suburb['name'],
                            'geom_raw': suburb['geom_raw']})
        return suburbs
    finally:
        # Closing the connection also discards its cursors, on every path.
        conn.close()


def pull_area(hexids, level, geom):
    """Return the overlap area of each ``public.area2`` row with *geom*.

    Args:
        hexids: unused by the query; kept so existing callers keep working.
        level: unused by the query; kept so existing callers keep working.
        geom: WKT text of the clipping geometry (interpreted as SRID 4326).

    Returns:
        A list of ``(area, data)`` tuples, one per row intersecting *geom*:
        ``area`` is ``ST_Area`` of the intersection after transforming to
        SRID 3035, ``data`` is the ``json.loads``-decoded ``data`` column.
    """
    query = """
                SELECT data, ST_AREA(ST_Transform(ST_Intersection(geom, ST_SetSRID(ST_GeomFromText(%s),4326)), 3035)) as area
                FROM public.area2
                where ST_Intersects(geom, ST_SetSRID(ST_GeomFromText(%s),4326));
            """
    conn = psycopg2.connect(host=PSQL_HOST,
                            port=PSQL_PORT,
                            user=PSQL_USER,
                            password=PSQL_PASSWORD,
                            database=PSQL_DATABASE)
    try:
        cursor = conn.cursor()
        # The WKT is bound twice as a parameter rather than formatted in.
        cursor.execute(query, (geom, geom))
        return [(row[1], json.loads(row[0])) for row in cursor]
    finally:
        # Release the connection even when the query or the decoding fails.
        conn.close()


def pull_nearest(hexids, level, geom):
    """Return ``public.nearest`` rows in the given hex cells with distances.

    Args:
        hexids: iterable of hex cell ids matched against column ``hex<level>``.
        level: resolution suffix of the ``hex<level>`` column; coerced to
            ``int`` so nothing but a number can reach the SQL text.
        geom: WKT text of the reference geometry (interpreted as SRID 4326).

    Returns:
        A list of ``(distance, data)`` tuples: ``distance`` is ``ST_Distance``
        between the row geometry and *geom*, both transformed to SRID 3035;
        ``data`` is the ``json.loads``-decoded ``data`` column. An empty
        *hexids* yields ``[]`` without touching the database (the equivalent
        ``IN ()`` would be a SQL syntax error).

    Raises:
        ValueError / TypeError: if *level* is not convertible to ``int``.
    """
    hexids = tuple(str(hexid) for hexid in hexids)
    if not hexids:
        return []

    # Only the column suffix is formatted in; the WKT and the id list are
    # bound as query parameters.
    query = """
                SELECT data, ST_DISTANCE(ST_Transform(geom, 3035), ST_TRANSFORM(ST_SetSRID(ST_GeomFromText(%s),4326), 3035)) as distance
                FROM public.nearest
                where hex{level} in %s;
            """.format(level=int(level))

    conn = psycopg2.connect(host=PSQL_HOST,
                            port=PSQL_PORT,
                            user=PSQL_USER,
                            password=PSQL_PASSWORD,
                            database=PSQL_DATABASE)
    try:
        cursor = conn.cursor()
        cursor.execute(query, (geom, hexids))
        return [(row[1], json.loads(row[0])) for row in cursor]
    finally:
        # Release the connection even when the query or the decoding fails.
        conn.close()


def _great_circle_distance(lat1, lat2, lon1, lon2):
        """Calculate the great circle distance from two lat / lon pairs."""
        lon1, lat1, lon2, lat2 = map(math.radians, [lon1, lat1, lon2, lat2])
        return 6371000 * (math.acos(math.sin(lat1) * math.sin(lat2) +
                                    math.cos(lat1) * math.cos(lat2) * math.cos(
                          lon1 - lon2)))

def h3_index(geometry, level):
    """Return the h3 cell ids at *level* that cover a WKT geometry.

    A Point yields a one-element list. A Polygon yields the polyfill cells
    plus the cell containing the polygon's centroid. Any other geometry type
    falls through and returns None.
    """
    ogr.UseExceptions()
    shape = ogr.CreateGeometryFromWkt(geometry)
    as_geojson = json.loads(shape.ExportToJson())

    # NOTE(review): the GeoJSON coordinate pair is unpacked positionally into
    # geo_to_h3 -- confirm the axis order matches what h3 expects.
    if as_geojson['type'] == 'Point':
        point_cell = h3.geo_to_h3(*as_geojson['coordinates'], level)
        return [point_cell]

    if as_geojson['type'] != 'Polygon':
        return None

    cells = h3.polyfill(as_geojson, level)
    # A polygon that is small relative to the hex size can polyfill to an
    # empty set, so the cell under the centroid is always included.
    centre = json.loads(shape.Centroid().ExportToJson())
    cells.add(h3.geo_to_h3(*centre['coordinates'], level))
    return list(cells)


def geometry_buffer(wkt, meters):
    """Return *wkt* grown by *meters*, as WKT.

    Approach reimplemented from the mmqgis plugin
    (http://michaelminn.com/linux/mmqgis/ - Create Buffers): project the
    geometry into an azimuthal-equidistant system centred on its own
    centroid, buffer there in metres, then project back to EPSG:4326.

    A *meters* of 0 returns *wkt* untouched. If the buffered result is empty
    the centroid of the input is returned instead. Any exception is printed
    and None is returned.
    """
    if meters == 0:
        return wkt

    try:
        ogr.UseExceptions()

        wgs84 = osr.SpatialReference()
        wgs84.ImportFromEPSG(4326)
        # NOTE(review): the original carried a disabled
        # wgs84.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) call
        # here -- confirm the axis order assumed below for the GDAL in use.

        shape = ogr.CreateGeometryFromWkt(wkt)
        centre = shape.Centroid()
        origin = (centre.GetY(), centre.GetX())

        local_metric = osr.SpatialReference()
        local_metric.ImportFromProj4(
            "+proj=aeqd +lat_0=%s +lon_0=%s +x_0=0 +y_0=0 +datum=WGS84"
            " +units=m +no_defs" % origin
        )

        to_metric = osr.CoordinateTransformation(wgs84, local_metric)
        to_wgs84 = osr.CoordinateTransformation(local_metric, wgs84)

        shape.Transform(to_metric)
        grown = shape.Buffer(meters, 7)
        grown.Transform(to_wgs84)

        if not grown.IsEmpty():
            return grown.ExportToWkt()
        return ogr.CreateGeometryFromWkt(wkt).Centroid().ExportToWkt()
    except Exception as e:
        print(e)
        return None

def gaussian(distance, max_distance, variance=None):
    """Return a Gaussian weight for *distance* relative to *max_distance*.

    The distance is first normalised to ``distance / max_distance`` and the
    zero-mean normal density with the given variance is evaluated there:
    ``exp(-d**2 / (2 * variance)) / sqrt(2 * pi * variance)``.

    The previous implementation had the ``exp(...)`` factor inside the square
    root in the denominator, which made the weight *grow* with distance.

    Args:
        distance: raw distance.
        max_distance: distance that maps to a normalised value of 1.
        variance: variance of the kernel; defaults to the module-level
            ``VARIANCE``.

    Raises:
        ZeroDivisionError: if *max_distance* is 0.
    """
    if variance is None:
        variance = VARIANCE
    scaled = distance / max_distance
    return (math.exp(-0.5 * scaled ** 2 / variance) /
            math.sqrt(2 * math.pi * variance))


class MinCollection:
    """Running list of the records with the smallest ``record[key]``.

    Starts with a single record; each add() re-sorts ascending by *key* and
    trims the list to at most *count* entries.
    """

    def __init__(self, data, key='distance', count=3):
        self.count = count
        self.key = key
        self.data = [data]

    def add(self, data):
        """Insert *data*, then keep only the *count* smallest records."""
        def rank(record):
            return record[self.key]

        self.data.append(data)
        ordered = sorted(self.data, key=rank)
        if len(ordered) > self.count:
            ordered = ordered[:self.count]
        self.data = ordered

    def get(self):
        """Return the current list of kept records, smallest first."""
        return self.data





# Radii (metres) of the nested feature-count rings, widest first.
_COUNT_RADII = (5000, 2000, 1500, 800)
# Radii (metres) of the area rings.
_AREA_RADII = (1000, 400, 200)
# NOTE(review): the shared secret is hard-coded in source -- consider moving
# it to configuration.
_APP_KEY = b'4ecd61c6-2353-4d31-b482-187156dba034'


def _reject(start_response, status):
    """Send a body-less JSON response with the given status line."""
    start_response(status, [('Content-type', 'application/json')])
    return []


def _parse_target(target):
    """Split a ``'POINT (a b)'``-style string into ``(lat, lon)`` tokens.

    ``lon`` is the text after the last ``(`` up to the last space and ``lat``
    the text after that space up to the last ``)``; both are returned as the
    original strings. Raises ValueError when *target* does not have that
    shape or either token is not a number.
    """
    try:
        head, tail = target.rsplit(' ', 1)
        lat = tail.rsplit(')', 1)[0]
        lon = head.rsplit('(', 1)[1]
    except (AttributeError, TypeError, ValueError, IndexError):
        raise ValueError('malformed target: %r' % (target,))
    # Both tokens end up inside WKT text sent to the database, so insist on
    # plain numbers before doing anything else with them.
    float(lon)
    float(lat)
    return lat, lon


def _count_by_fclass(items):
    """Count items per ``fclass`` value."""
    counts = {}
    for item in items:
        fclass = item['fclass']
        counts[fclass] = counts.get(fclass, 0) + 1
    return counts


def _area_by_fclass(wkt, radius, overshoot, hexno):
    """Sum the area per ``fclass`` inside a *radius*-metre buffer of *wkt*."""
    buff_geom = geometry_buffer(wkt, radius + overshoot)
    real_geom = geometry_buffer(wkt, radius)
    poly_idx = h3_index(buff_geom, hexno)

    totals = {}
    for size, props in pull_area(poly_idx, hexno, real_geom):
        fclass = props['fclass']
        if fclass in totals:
            totals[fclass] += float(size)
        else:
            totals[fclass] = float(size)
    return totals


def _nearest_by_fclass(wkt, radius, overshoot, hexno):
    """Collect the closest record(s) per ``fclass`` within *radius* metres.

    Six records are kept for ``'suburb_name'`` and one for every other class.
    """
    poly_idx = h3_index(geometry_buffer(wkt, radius + overshoot), hexno)

    best = {}
    for dist, props in pull_nearest(poly_idx, hexno, wkt):
        props['distance'] = float(dist)
        if not props['distance'] <= radius:
            continue
        fclass = props['fclass']
        if fclass in best:
            best[fclass].add(props)
        else:
            keep = 6 if fclass == 'suburb_name' else 1
            best[fclass] = MinCollection(props, count=keep)
    return {fclass: kept.get() for fclass, kept in best.items()}


def _process_target(lat, lon):
    """Build the statistics record for one target point.

    *lat* and *lon* are the coordinate tokens as strings; they are echoed
    back unchanged in the result.
    """
    overshoot = 700
    hexno = 9
    wkt = 'POINT (%s %s)' % (lon, lat)
    origin_lat = float(lat)
    origin_lon = float(lon)
    h3ind = h3.geo_to_h3(origin_lon, origin_lat, hexno)

    # Feature counts: fetch once for the widest ring, then narrow ring by
    # ring so each smaller radius only scans the survivors of the previous.
    poly_idx = h3_index(
        geometry_buffer(wkt, _COUNT_RADII[0] + overshoot), hexno)
    features = pull_data(poly_idx, 'osm', hexno)
    for item in features:
        tokens = item['geom_raw'].split(' ')
        item['lat'] = tokens[-1].strip(')')
        item['lon'] = tokens[-2].split('(')[1]
        item['distance'] = _great_circle_distance(
            origin_lat, float(item['lat']), origin_lon, float(item['lon']))

    counts = {}
    for radius in _COUNT_RADII:
        features = [f for f in features if f['distance'] <= radius]
        counts[radius] = _count_by_fclass(features)

    # Census: sum the values of the census points within 400 m.
    census_radius = 400
    poly_idx = h3_index(
        geometry_buffer(wkt, census_radius + overshoot), hexno)
    census_total = 0
    for item in pull_data(poly_idx, 'census', hexno):
        tokens = item['geom_raw'].split(' ')
        dist = _great_circle_distance(
            origin_lat, float(tokens[2].strip(')')),
            origin_lon, float(tokens[1].strip('(')))
        if dist <= census_radius:
            census_total += int(item['value'])

    # Area per class for each ring.
    areas = {}
    for radius in _AREA_RADII:
        areas[radius] = _area_by_fclass(wkt, radius, overshoot, hexno)

    # Nearest features: wider search on a coarser hex resolution.
    nearest = _nearest_by_fclass(wkt, 50000, 4000, 6)

    suburbs = get_suburbs(wkt)

    # Key order is kept as in the original response payload.
    return {'lat': lat, 'lon': lon, 'hex': h3ind,
            'count_5000': counts[5000],
            'count_2000': counts[2000],
            'count_1500': counts[1500],
            'count_800': counts[800],
            'nearest': nearest,
            'area_1000': areas[1000],
            'area_400': areas[400],
            'area_200': areas[200],
            'census_400': census_total,
            'suburbs': suburbs}


def application(environ, start_response):
    """WSGI entry point: compute surroundings statistics for target points.

    Reads a form-encoded request body with two fields:

    * ``appkey`` -- must equal the expected application key, otherwise the
      response is ``403 Forbidden`` with an empty body.
    * ``targets`` -- a JSON list of ``'POINT (a b)'`` strings. A missing
      field, invalid JSON, a non-list value or a malformed/non-numeric point
      yields ``400 Bad Request`` with an empty body (previously these crashed
      the handler with an unhandled exception).

    A ``distances`` field may be sent but is not consulted; the radii are
    fixed.

    On success the response is ``200 OK`` with a JSON list holding one record
    per target (feature counts, nearest features, areas, census sum and
    suburbs).

    The handler no longer opens its own database connection: the original
    one was never used and never closed, leaking a connection per request.
    """
    try:
        request_body_size = int(environ.get('CONTENT_LENGTH', 0))
    except (TypeError, ValueError):
        request_body_size = 0

    form = parse_qs(environ['wsgi.input'].read(request_body_size))

    if form.get(b'appkey', None) != [_APP_KEY]:
        return _reject(start_response, '403 Forbidden')

    raw_targets = form.get(b'targets', None)
    if not raw_targets:
        return _reject(start_response, '400 Bad Request')

    # Validate every target before any database work is started.
    try:
        targets = json.loads(raw_targets[0].decode())
        if not isinstance(targets, list):
            raise ValueError('targets must be a JSON list')
        coordinates = [_parse_target(target) for target in targets]
    except ValueError:
        return _reject(start_response, '400 Bad Request')

    response = [_process_target(lat, lon) for lat, lon in coordinates]

    start_response('200 OK', [('Content-type', 'application/json')])
    return [bytes(json.dumps(response), 'utf-8')]
