#!/usr/bin/env python
# Copyright (c) 2014 Pivotal inc.


import datetime
import subprocess
import os
import platform
import sys
import re
from contextlib import closing
from distutils.version import LooseVersion
from optparse import OptionParser
from pygresql import pgdb

# Version banner handed to OptionParser in parseCmdLine(); optparse expands
# %prog to the program name.
gpsd_version = '%prog 1.0'

# SQL list literal of schema names; it is appended after "NOT IN" in the
# queries built by dumpTupleCount() and dumpStats(), so relations in these
# schemas are left out of the dump.
sysnslist = "('pg_toast', 'pg_bitmapindex', 'pg_temp_1', 'pg_catalog', 'information_schema')"
# Overwritten by main() with the result of isOrcaPresent(); dumpGlobals()
# reads it to choose between --no-gp-syntax and --gp-syntax.
orca = False
# Server options that main() exports through PGOPTIONS to the child
# processes and embeds in the pgdb connection string.
pgoptions = '-c gp_session_role=utility'


def ResultIter(cursor, arraysize=1000):
    """Yield the rows of an executed cursor one at a time.

    Rows are pulled with ``cursor.fetchmany(arraysize)`` so that at most one
    batch is held in memory; iteration stops at the first empty batch.
    """
    batch = cursor.fetchmany(arraysize)
    while batch:
        for row in batch:
            yield row
        batch = cursor.fetchmany(arraysize)


def getVersion(envOpts):
    """Return the output of ``select version()`` run through psql on template1.

    envOpts is used as the environment of the psql child process.  The
    returned string is psql's raw stdout.  If psql exits with a non-zero
    status, its stderr is reported on our stderr and the program exits with
    status 1.
    """
    cmd = subprocess.Popen('psql --pset footer -Atqc "select version()" template1',
                           shell=True, stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE, env=envOpts)
    # communicate() drains both pipes while waiting for the child; calling
    # wait() first can block forever once the child fills a pipe buffer.
    out, err = cmd.communicate()
    # Compare by value: "is not 0" relies on the interpreter caching small ints.
    if cmd.returncode != 0:
        sys.stderr.write('\nError while trying to find HAWQ/GPDB version.\n\n' + err + '\n\n')
        sys.exit(1)
    return out


def dumpSchema(connectionInfo, envOpts):
    """Run pg_dump (schema only) and let it write straight to our stdout.

    connectionInfo is a (host, port, user, database) tuple interpolated into
    the command line; envOpts is the environment of the child process.  If
    pg_dump exits with a non-zero status, its stderr is reported and the
    program exits with status 1.
    """
    # NOTE: the values are placed unquoted into a shell command line, so
    # spaces or shell metacharacters in them are interpreted by the shell.
    dmp_cmd = 'pg_dump -h %s -p %s -U %s -s -x --gp-syntax -O %s' % connectionInfo
    p = subprocess.Popen(dmp_cmd, shell=True, stderr=subprocess.PIPE, env=envOpts)
    # communicate() keeps reading stderr while waiting; a bare wait() can
    # deadlock once the child has filled the stderr pipe buffer.
    _, err = p.communicate()
    # Compare by value: "is not 0" relies on the interpreter caching small ints.
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping schema.\n\n' + err + '\n\n')
        sys.exit(1)


def dumpGlobals(connectionInfo, envOpts):
    """Run pg_dumpall -g and let it write straight to our stdout.

    connectionInfo is a (host, port, user, database) tuple interpolated into
    the command line; envOpts is the environment of the child process.  The
    module-level ``orca`` flag selects --no-gp-syntax (set) or --gp-syntax
    (unset).  If pg_dumpall exits with a non-zero status, its stderr is
    reported and the program exits with status 1.
    """
    if orca:
        syntax_flag = '--no-gp-syntax'
    else:
        syntax_flag = '--gp-syntax'
    # NOTE: the values are placed unquoted into a shell command line, so
    # spaces or shell metacharacters in them are interpreted by the shell.
    dmp_cmd = ('pg_dumpall -h %s -p %s -U %s -l %s -g ' + syntax_flag) % connectionInfo
    p = subprocess.Popen(dmp_cmd, shell=True, stderr=subprocess.PIPE, env=envOpts)
    # communicate() keeps reading stderr while waiting; a bare wait() can
    # deadlock once the child has filled the stderr pipe buffer.
    _, err = p.communicate()
    # Compare by value: "is not 0" relies on the interpreter caching small ints.
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping globals.\n\n' + err + '\n\n')
        sys.exit(1)


def dumpTupleCount(cur):
    """Print an UPDATE of pg_class.relpages/reltuples for each relation.

    Every relation whose schema is not listed in ``sysnslist`` gets one
    statement on stdout that sets relpages and reltuples to the values read
    through cur.

    cur is a DB-API cursor; it is executed here and consumed via ResultIter.
    """
    stmt = ("SELECT pgc.relname, pgn.nspname, pgc.relpages, pgc.reltuples "
            "FROM pg_class pgc, pg_namespace pgn "
            "WHERE pgc.relnamespace = pgn.oid and pgn.nspname NOT IN " + sysnslist)
    # {0}: raw relation name (comment line), {1}: SET assignments,
    # {2}/{3}: relation and schema name escaped for use in a SQL literal.
    setStmt = ('-- Table: {0}\n'
               'UPDATE pg_class\nSET\n'
               '{1}\n'
               "WHERE relname = '{2}' AND relnamespace = "
               "(SELECT oid FROM pg_namespace WHERE nspname = '{3}');\n\n")

    def quoteLiteral(text):
        # Double embedded single quotes so a name such as o'brien still
        # yields a valid SQL string literal.
        return text.replace("'", "''")

    cur.execute(stmt)
    # A real list (not map()) so that it can be sliced below.
    columns = [col[0] for col in cur.description]
    types = ['int', 'real']
    for vals in ResultIter(cur):
        relname, nspname = vals[0], vals[1]
        # One "column = value::type" line per statistic column; the first
        # two result columns are the names handled above.
        assignments = ',\n'.join(
            '\t%s = %s::%s' % item
            for item in zip(columns[2:], vals[2:], types))
        sys.stdout.write(setStmt.format(relname, assignments,
                                        quoteLiteral(relname),
                                        quoteLiteral(nspname)) + '\n')


def dumpStats(cur, inclHLL):
    """Print an ``INSERT INTO pg_statistic`` statement for each statistics row.

    The rows are those of pg_statistic joined to their relation, schema,
    attribute and attribute type, for relations whose schema is not listed in
    ``sysnslist``.  Each value is emitted as ``value::type``.

    cur is a DB-API cursor; it is executed here and consumed via ResultIter.
    inclHLL (evaluated for truthiness) controls rows treated as HLL
    statistics (kind 98 or 99 in the column at position 10): when set, the
    value at position 25 is emitted as ``bytea[]``; otherwise the kind is
    written as 0 and the value as ``NULL::int4[]``.
    """
    query = 'SELECT pgc.relname, pgn.nspname, pga.attname, pgt.typname, pgs.* ' \
        'FROM pg_class pgc, pg_statistic pgs, pg_namespace pgn, pg_attribute pga, pg_type pgt ' \
        'WHERE pgc.relnamespace = pgn.oid and pgn.nspname NOT IN ' + \
        sysnslist + \
        ' and pgc.oid = pgs.starelid ' \
        'and pga.attrelid = pgc.oid ' \
        'and pga.attnum = pgs.staattnum ' \
        'and pga.atttypid = pgt.oid ' \
        'ORDER BY pgc.relname, pgs.staattnum'
    pstring = '--\n' \
        '-- Table: {0}, Attribute: {1}\n' \
        '--\n' \
        'INSERT INTO pg_statistic VALUES (\n' \
        '{2});\n\n'
    # Cast types for the result columns from vals[5] onwards; five more
    # entries based on the attribute's own type are appended per row below.
    types = ['smallint',  # staattnum
             'boolean',
             'real',
             'integer',
             'real',
             'smallint',
             'smallint',
             'smallint',
             'smallint',
             'smallint',
             'oid',
             'oid',
             'oid',
             'oid',
             'oid',
             'real[]',
             'real[]',
             'real[]',
             'real[]',
             'real[]'
             ]
    # 1-based positions counted from vals[5]; presumably the fifth "kind"
    # slot and the fifth "values" slot given the type layout above.
    hllKindPos = 10
    hllValuesPos = 25
    hllKinds = (98, 99)

    cur.execute(query)

    for vals in ResultIter(cur):
        relname, nspname, attname, typname = vals[0], vals[1], vals[2], vals[3]
        # NOTE(review): the names go into the regclass literal unquoted;
        # mixed-case or otherwise quoted identifiers presumably will not
        # resolve when the dump is loaded - confirm.
        rowVals = ["\t'%s.%s'::regclass" % (nspname, relname)]

        # A type name starting with '_' is used as-is; any other name gets
        # '[]' appended.
        if typname.startswith('_'):
            valueType = typname
        else:
            valueType = typname + '[]'
        rowTypes = types + [valueType] * 5

        hll = False
        for pos, (val, typ) in enumerate(zip(vals[5:], rowTypes), 1):
            raw = val  # untouched value, needed for the bytea[] form below
            if pos == hllKindPos and val in hllKinds:
                hll = True
                if not inclHLL:
                    val = 0

            if val is None:
                val = 'NULL'
            elif isinstance(val, (str, unicode)) and val.startswith('{'):
                # Array text: double single quotes and backslashes, then
                # wrap in an E'' literal.  startswith() also copes with an
                # empty string, where val[0] would raise IndexError.
                val = "E'" + val.replace("'", "''").replace('\\', '\\\\') + "'"

            if pos == hllValuesPos and hll:
                if inclHLL:
                    rowVals.append('\t{0}::{1}'.format("'%s'" % raw, 'bytea[]'))
                else:
                    rowVals.append('\tNULL::int4[]')
            else:
                rowVals.append('\t{0}::{1}'.format(val, typ))
        sys.stdout.write(pstring.format(relname, attname, ',\n'.join(rowVals)) + '\n')


def parseCmdLine():
    """Build and return the OptionParser for the gpsd command line.

    The parser is created with conflict_handler="resolve" and a '-h/--host'
    option is registered after '-?/--help', so '-h' selects the host.
    """
    privacy_note = ("WARNING: This tool collects statistics about your data, "
                    "including most common values, which requires some data "
                    "elements to be included in the output file. Please review "
                    "output file to ensure it is within corporate policy to "
                    "transport the output file.")
    parser = OptionParser(usage='Usage: %prog [options] <database>',
                          version=gpsd_version,
                          conflict_handler="resolve",
                          epilog=privacy_note)

    # (flags, keyword arguments) in the order they appear in --help output.
    option_table = (
        (('-?', '--help'),
         dict(action='help', help='Show this help message and exit')),
        (('-h', '--host'),
         dict(action='store', dest='host', help='Specify a remote host')),
        (('-p', '--port'),
         dict(action='store', dest='port',
              help='Specify a port other than 5432')),
        (('-U', '--user'),
         dict(action='store', dest='user',
              help='Connect as someone other than current user')),
        (('-s', '--stats-only'),
         dict(action='store_false', dest='dumpSchema', default=True,
              help='Just dump the stats, do not do a schema dump')),
        (('-l', '--hll'),
         dict(action='store_true', dest='dumpHLL', default=False,
              help='Include HLL stats')),
    )
    for flags, settings in option_table:
        parser.add_option(*flags, **settings)
    return parser


def isOrcaPresent(versionString):
    """Decide from a server version string whether the server is ORCA aware.

    Returns True when the string mentions HAWQ.  Otherwise the number that
    follows "Greenplum Database " is compared against 4.3 and the result of
    that comparison is returned.  When neither is found, an error is written
    to stderr and the program exits with status 1.
    """
    if 'HAWQ' in versionString:
        return True

    found = re.search('(?<=Greenplum Database )[0-9.]+', versionString)
    gpdbVersion = found.group() if found else ''
    if gpdbVersion:
        return LooseVersion(gpdbVersion) >= LooseVersion('4.3')

    # Neither HAWQ nor a recognisable GPDB version number
    sys.stderr.write('Cannot determine HAWQ/GPDB version.  Exiting.\n')
    sys.exit(1)


def main():
    """Entry point: write a statistics dump for one database to stdout.

    Parses the command line, determines the server version, prints a header,
    optionally dumps globals and schema through pg_dumpall/pg_dump, and then
    prints the pg_class and pg_statistic statements produced by
    dumpTupleCount() and dumpStats().  Exits with status 1 when the
    statistics cannot be read from the database.
    """
    parser = parseCmdLine()
    options, args = parser.parse_args()
    if len(args) != 1:
        # parser.error() prints the usage text and terminates the program
        # itself, so nothing after it would run.
        parser.error("No database specified")

    # OK - now let's setup all the arguments & options
    db = args[0]
    host = options.host or platform.node()
    user = options.user
    if not user:
        try:
            user = os.getlogin()
        except OSError:
            # getlogin() needs a controlling terminal and fails without one
            # (e.g. when started from cron); fall back to the environment /
            # password database lookup.
            import getpass
            user = getpass.getuser()
    port = options.port or '5432'
    inclSchema = options.dumpSchema
    inclHLL = options.dumpHLL

    # NOTE: this is os.environ itself, not a copy, so PGOPTIONS is also set
    # for the current process.
    envOpts = os.environ
    envOpts['PGOPTIONS'] = pgoptions

    version = getVersion(envOpts)

    # Let's update the global variable so we can tell if we're using ORCA aware or not
    global orca
    orca = isOrcaPresent(version)

    if orca:
        envOpts['PGOPTIONS'] = pgoptions + ' -c optimizer=off'

    timestamp = datetime.datetime.today()

    # setup the connection info tuple with options
    connectionInfo = (host, port, user, db)
    connectionString = ':'.join([host, port, db, user, '', pgoptions, ''])

    sys.stdout.writelines(['\n-- Greenplum database Statistics Dump',
                           '\n-- Copyright (C) 2007 - 2014 Pivotal',
                           '\n-- Database: ' + db,
                           '\n-- Date:     ' + timestamp.date().isoformat(),
                           '\n-- Time:     ' + timestamp.time().isoformat(),
                           '\n-- CmdLine:  ' + ' '.join(sys.argv),
                           '\n-- Version:  ' + version + '\n\n'])

    # The dump helpers start child processes that write directly to our
    # stdout, so flush our own buffered output before each of them.
    sys.stdout.flush()
    if inclSchema:
        dumpGlobals(connectionInfo, envOpts)
    # Now be sure that when we load the rest we are doing it in the right
    # database
    sys.stdout.write('\\connect ' + db + '\n\n')
    sys.stdout.flush()
    if inclSchema:
        dumpSchema(connectionInfo, envOpts)

    # Now we have to explicitly allow editing of these pg_class &
    # pg_statistic tables
    sys.stdout.writelines(['\n-- ',
                           '\n-- Allow system table modifications',
                           '\n-- ',
                           '\nset allow_system_table_mods=true;\n\n'])
    sys.stdout.flush()

    try:
        with closing(pgdb.connect(connectionString)) as connection:
            with closing(connection.cursor()) as cursor:
                dumpTupleCount(cursor)
                dumpStats(cursor, inclHLL)
    except pgdb.DatabaseError as err:  # database errors only; others propagate
        sys.stderr.write('Error while dumping statistics:\n')
        sys.stderr.write(str(err) + '\n')
        sys.exit(1)

    # Upon success, leave a warning message about data collected
    sys.stderr.writelines(['WARNING: This tool collects statistics about your data, including most common values, '
                           'which requires some data elements to be included in the output file.\n',
                           'Please review output file to ensure it is within corporate policy to transport the output file.\n'])


# Run the dump only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
