#!/usr/libexec/platform-python

import re
import time
import subprocess
import argparse
import sys
import datetime

class Colors:
    """ANSI escape sequences used to colorize terminal output."""
    # Foreground color selectors (ESC [ 3x m).
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    # Resets all attributes; printed after a colored span to end it.
    RESET = '\033[0m'

# Basic helpers
###############################

def get_smc_stats(mode='smcr'):
    """Run ``<mode> -d s`` and extract RX/TX counters from its text output.

    Args:
        mode: Name of the executable to run. The callers in this file pass
            'smcr' or 'smcd'.

    Returns:
        A 6-tuple of ints
        ``(rx_bytes, rx_reqs, tx_bytes, tx_reqs, rx_rmb_bytes, tx_rmb_bytes)``.
        A field whose pattern is not found in the output is 0. If the command
        cannot be started or exits with a non-zero status, an error line is
        printed and all six fields are 0.
    """
    try:
        process = subprocess.run([mode, '-d', 's'], stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE, universal_newlines=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: '{mode} -d s': {e.stderr}")
        return 0, 0, 0, 0, 0, 0
    except OSError as e:
        # The tool is missing or not executable (e.g. FileNotFoundError).
        # Report it the same way as a failing command instead of crashing.
        print(f"Error executing command: '{mode} -d s': {e}")
        return 0, 0, 0, 0, 0, 0

    output = process.stdout

    def first_int(section, label):
        # Non-greedy DOTALL match: the first `label` line that follows the
        # "<section> Stats" heading, capturing the integer after the label.
        found = re.search(section + r' Stats.*?' + label + r'\s+(\d+)', output, re.DOTALL)
        return int(found.group(1)) if found else 0

    data_label = r'Data transmitted \(Bytes\)'
    reqs_label = r'Total requests'
    bufs_label = r'Buffer usage \(Bytes\)'

    rx_bytes = first_int('RX', data_label)
    tx_bytes = first_int('TX', data_label)
    rx_reqs = first_int('RX', reqs_label)
    tx_reqs = first_int('TX', reqs_label)
    rx_rmb_bytes = first_int('RX', bufs_label)
    tx_rmb_bytes = first_int('TX', bufs_label)

    return rx_bytes, rx_reqs, tx_bytes, tx_reqs, rx_rmb_bytes, tx_rmb_bytes

def get_smc_conn_cnts():
    """Run ``smcss -a`` and count connections by the mode tag in its output.

    The counts are the number of occurrences of the substrings 'SMCR',
    'SMCD' and 'TCP' anywhere in the command's stdout (presumably the mode
    column of each connection line -- verify against the smcss output format).

    Returns:
        A 3-tuple ``(smcr_cnt, smcd_cnt, fallback_cnt)`` of ints. If the
        command cannot be started or exits with a non-zero status, an error
        line is printed and ``(0, 0, 0)`` is returned. Every caller in this
        file unpacks three values, so a tuple is returned on failure too,
        matching the all-zero failure result of get_smc_stats().
    """
    try:
        res = subprocess.run(['smcss', '-a'], stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, universal_newlines=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: 'smcss -a': {e.stderr}")
        return 0, 0, 0
    except OSError as e:
        # smcss is missing or not executable (e.g. FileNotFoundError).
        print(f"Error executing command: 'smcss -a': {e}")
        return 0, 0, 0

    listing = res.stdout
    # Plain substring counts; the patterns are literals, so no regex is needed.
    smcr_cnt = listing.count('SMCR')
    smcd_cnt = listing.count('SMCD')
    fallback_cnt = listing.count('TCP')

    return smcr_cnt, smcd_cnt, fallback_cnt

# Transfer Rate
###############################

def convert_rps(rps, raw=False):
    """Format a requests-per-second value as text with a '/s' style unit.

    With ``raw`` set, or when ``rps`` equals 0 or -1, the value is printed
    as-is followed by ' /s'. Otherwise it is divided by 1024 until it is
    below 1024 (at most three times) and printed with two decimals and the
    matching unit: '/s', 'K/s', 'M/s' or 'G/s'.
    """
    # Raw mode and the 0 / -1 values are echoed without scaling or rounding.
    if raw or rps in (0, -1):
        return f"{rps} /s"

    units = ("/s", "K/s", "M/s", "G/s")
    unit = units[0]
    for bigger in units[1:]:
        if rps >= 1024:
            rps /= 1024
            unit = bigger
        else:
            break
    return f"{rps:.2f} {unit}"

def convert_speed(size_bytes, raw=False):
    """Format a bytes-per-second value as text with a 'B/s' style unit.

    With ``raw`` set, or when ``size_bytes`` equals 0 or -1, the value is
    printed as-is followed by ' B/s'. Otherwise it is divided by 1024 until
    it is below 1024 (at most three times) and printed with two decimals and
    the matching unit: 'B/s', 'KB/s', 'MB/s' or 'GB/s'.
    """
    # Raw mode and the 0 / -1 values are echoed without scaling or rounding.
    if raw or size_bytes in (0, -1):
        return f"{size_bytes} B/s"

    units = ("B/s", "KB/s", "MB/s", "GB/s")
    unit = units[0]
    for bigger in units[1:]:
        if size_bytes >= 1024:
            size_bytes /= 1024
            unit = bigger
        else:
            break
    return f"{size_bytes:.2f} {unit}"

def get_transfer_cnts(mode='smcr'):
    """Return only the transfer counters reported by get_smc_stats(mode).

    The result is ``(rx_bytes, rx_reqs, tx_bytes, tx_reqs)``; the two
    trailing buffer-usage fields of get_smc_stats() are dropped.
    """
    stats = get_smc_stats(mode)
    return stats[0], stats[1], stats[2], stats[3]

def display_transfer_rates(interval, raw, mode):
    """Print RX/TX byte and request rates per mode, derived from counter deltas.

    Args:
        interval: Seconds to sleep between samples. 0 means: sleep one
            second, print a single table and return.
        raw: Forwarded to convert_speed()/convert_rps() to disable unit scaling.
        mode: 'smc' selects both 'smcr' and 'smcd'; any other value is used
            as the single mode name.

    Each rate is (current - previous) / interval; when a counter is lower
    than in the previous sample the rate is shown as -1. Ctrl-C ends the loop.
    """
    selected = ['smcr', 'smcd'] if mode == 'smc' else [mode]

    # A zero interval requests a one-shot display; a delta still needs two
    # samples, so wait one second between them.
    one_shot = interval == 0
    if one_shot:
        interval = 1

    def per_second(now, before):
        if now >= before:
            return (now - before) / interval
        return -1

    header = f"{Colors.BLUE}{'Date':<22}{'Mode':<5}{'Rx Rate':>18}{'Rx Rps':>18}{'Tx Rate':>18}{'Tx Rps':>18}{Colors.RESET}"

    previous = {m: get_transfer_cnts(m) for m in selected}
    try:
        while True:
            time.sleep(interval)
            print(header)
            current = {m: get_transfer_cnts(m) for m in selected}
            for m in selected:
                # Counter order: rx_bytes, rx_reqs, tx_bytes, tx_reqs.
                rx_rate, rx_rps, tx_rate, tx_rps = (
                    per_second(now, before)
                    for now, before in zip(current[m], previous[m])
                )
                cells = (
                    convert_speed(rx_rate, raw),
                    convert_rps(rx_rps, raw),
                    convert_speed(tx_rate, raw),
                    convert_rps(tx_rps, raw),
                )
                columns = "".join(f"{cell:>18}" for cell in cells)
                stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(f"{stamp:<22}{Colors.GREEN}{m:<5}{Colors.RESET}{columns}")
            previous = current
            if one_shot:
                break
    except KeyboardInterrupt:
        print("\nProgram exited by user")


# Connection Cnts
###############################

def display_conn_cnts(interval, mode):
    """Print a table with the number of connections per mode.

    Args:
        interval: Seconds to sleep between refreshes; 0 prints the table
            once and returns.
        mode: 'all' shows smcr, smcd and fallback; 'smc' shows smcr and
            smcd; any other value is used as the single row to show.

    Rows are always printed in the order smcr, smcd, fallback. Ctrl-C ends
    the loop.
    """
    if mode == 'all':
        modes = ['smcr', 'smcd', 'fallback']
    elif mode == 'smc':
        modes = ['smcr', 'smcd']
    else:
        modes = [mode]

    # Same order as the tuple returned by get_smc_conn_cnts().
    row_names = ('smcr', 'smcd', 'fallback')
    try:
        while True:
            print(f"{Colors.BLUE}{'Date':<22}{'Mode':<8}{'#Conn':>18}{Colors.RESET}")
            counts = get_smc_conn_cnts()
            if counts is None:
                # get_smc_conn_cnts() may signal a failed smcss call with
                # None (it has already printed the error); show zeros rather
                # than crash on tuple unpacking.
                counts = (0, 0, 0)
            current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            for name, cnt in zip(row_names, counts):
                if name in modes:
                    print(f"{current_time:<22}{Colors.GREEN}{name:<8}{Colors.RESET}{cnt:>18}")

            if interval != 0:
                time.sleep(interval)
            else:
                break
    except KeyboardInterrupt:
        print("\nProgram exited by user")


# Ringbufs Stats
###############################

def convert_bufsize(size_bytes, raw=False):
    """Format a byte count as text with a 'B' style unit.

    With ``raw`` set the value is printed as-is followed by ' B'. A value
    equal to 0 is printed as '0B'. Otherwise the value is divided by 1024
    until it is below 1024 (at most three times) and printed with two
    decimals and the matching unit: 'B', 'KB', 'MB' or 'GB'.
    """
    if raw:
        return f"{size_bytes} B"
    if size_bytes == 0:
        return "0B"

    units = ("B", "KB", "MB", "GB")
    unit = units[0]
    for bigger in units[1:]:
        if size_bytes >= 1024:
            size_bytes /= 1024
            unit = bigger
        else:
            break
    return f"{size_bytes:.2f} {unit}"

def get_ringbuf_usages(mode='smcr'):
    """Return only the buffer-usage fields reported by get_smc_stats(mode).

    The result is ``(rx_rmb_bytes, tx_rmb_bytes)``, i.e. the last two of the
    six values returned by get_smc_stats().
    """
    stats = get_smc_stats(mode)
    return stats[4], stats[5]

def display_ringbuf_usages(interval, raw, mode='smc'):
    """Print the RX/TX buffer usage per mode as a table.

    Args:
        interval: Seconds to sleep between refreshes; 0 prints the table
            once and returns.
        raw: Forwarded to convert_bufsize() to disable unit scaling.
        mode: 'smc' selects both 'smcr' and 'smcd'; any other value is used
            as the single mode name.

    Ctrl-C ends the loop.
    """
    selected = ['smcr', 'smcd'] if mode == 'smc' else [mode]
    header = f"{Colors.BLUE}{'Date':<22}{'Mode':<5}{'Rx Bufs':>18}{'Tx Bufs':>18}{Colors.RESET}"

    try:
        while True:
            print(header)
            usages = {m: get_ringbuf_usages(m) for m in selected}
            # One timestamp is shared by all rows of the same refresh.
            stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            for m in selected:
                rx_used, tx_used = usages[m]
                columns = "".join(
                    f"{convert_bufsize(used, raw):>18}" for used in (rx_used, tx_used)
                )
                print(f"{stamp:<22}{Colors.GREEN}{m:<5}{Colors.RESET}{columns}")
            if interval == 0:
                break
            time.sleep(interval)
    except KeyboardInterrupt:
        print("\nProgram exited by user")

def display_base_info(interval, raw, mode='all'):
    """Print a combined table: transfer rates, connection count and buffer usage.

    Args:
        interval: Seconds to sleep between samples. 0 means: sleep one
            second, print a single table and return.
        raw: Forwarded to convert_speed()/convert_rps()/convert_bufsize() to
            disable unit scaling.
        mode: 'all' shows smcr, smcd and fallback; 'smc' shows smcr and
            smcd; 'fallback' shows only the fallback row; any other value is
            used as the single mode name.

    The fallback row carries only a connection count; its rate and buffer
    columns are shown as 'N/A'. Each rate is (current - previous) / interval;
    when a counter is lower than in the previous sample the rate is shown as
    -1. Ctrl-C ends the loop.
    """
    if mode == 'all':
        modes = ['smcr', 'smcd', 'fallback']
        stats_modes = ['smcr', 'smcd']
    elif mode == 'smc':
        modes = ['smcr', 'smcd']
        stats_modes = ['smcr', 'smcd']
    elif mode == 'fallback':
        modes = ['fallback']
        stats_modes = []
    else:
        modes = [mode]
        stats_modes = [mode]

    # A zero interval requests a one-shot display; a delta still needs two
    # samples, so wait one second between them.
    display_once = False
    if interval == 0:
        interval = 1
        display_once = True

    def rate(curr, prev):
        return ((curr - prev) / interval) if (curr >= prev) else -1

    prev_stats = {stats_m: get_smc_stats(stats_m) for stats_m in stats_modes}
    try:
        while True:
            time.sleep(interval)
            print(f"{Colors.BLUE}{'Date':<22}{'Mode':<9}{'Rx Rate':>15}{'Rx Rps':>12}{'Tx Rate':>15}{'Tx Rps':>12}{'#Conn':>10}{'Rx Bufs':>12}{'Tx Bufs':>12}{Colors.RESET}")
            curr_stats = {stats_m: get_smc_stats(stats_m) for stats_m in stats_modes}
            conn_cnts = get_smc_conn_cnts()
            if conn_cnts is None:
                # get_smc_conn_cnts() may signal a failed smcss call with
                # None (it has already printed the error); show zeros rather
                # than crash on tuple unpacking.
                conn_cnts = (0, 0, 0)
            smcr_cnt, smcd_cnt, fallback_cnt = conn_cnts
            for m in modes:
                if m != 'fallback':
                    curr_rx_bytes, curr_rx_reqs, curr_tx_bytes, curr_tx_reqs, rx_rmb_bytes, tx_rmb_bytes = curr_stats[m]
                    prev_rx_bytes, prev_rx_reqs, prev_tx_bytes, prev_tx_reqs, _, _ = prev_stats[m]
                    rx_rate_f = f"{convert_speed(rate(curr_rx_bytes, prev_rx_bytes), raw):>15}"
                    rx_rps_f = f"{convert_rps(rate(curr_rx_reqs, prev_rx_reqs), raw):>12}"
                    tx_rate_f = f"{convert_speed(rate(curr_tx_bytes, prev_tx_bytes), raw):>15}"
                    tx_rps_f = f"{convert_rps(rate(curr_tx_reqs, prev_tx_reqs), raw):>12}"
                    rx_rmb_bytes_f = f"{convert_bufsize(rx_rmb_bytes, raw):>12}"
                    tx_rmb_bytes_f = f"{convert_bufsize(tx_rmb_bytes, raw):>12}"
                    conn_cnt = smcr_cnt if m == 'smcr' else smcd_cnt
                else:
                    # No stats command is run for fallback connections.
                    rx_rate_f = f"{'N/A':>15}"
                    rx_rps_f = f"{'N/A':>12}"
                    tx_rate_f = f"{'N/A':>15}"
                    tx_rps_f = f"{'N/A':>12}"
                    rx_rmb_bytes_f = f"{'N/A':>12}"
                    tx_rmb_bytes_f = f"{'N/A':>12}"
                    conn_cnt = fallback_cnt
                current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(f"{current_time:<22}{Colors.GREEN}{m:<9}{Colors.RESET}{rx_rate_f}{rx_rps_f}{tx_rate_f}{tx_rps_f}{conn_cnt:>10}{rx_rmb_bytes_f}{tx_rmb_bytes_f}")
            prev_stats = curr_stats
            if display_once:
                break
    except KeyboardInterrupt:
        print("\nProgram exited by user")

def _non_negative_int(text):
    """argparse ``type=`` callable: parse ``text`` as an integer >= 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}")
    if value < 0:
        # A negative interval would reach time.sleep() and raise ValueError.
        raise argparse.ArgumentTypeError(f"interval must be >= 0, got {value}")
    return value


def main():
    """Parse the command line and run the selected display command.

    Sub-commands (with one-letter aliases): speed/s, connection/c, memory/m
    and base/b. Each accepts -i/--interval (non-negative integer, default 0)
    and -m/--mode; all but 'connection' also accept -r/--raw.

    When no sub-command is given, the help text is printed to stderr and the
    process exits with status 1. Invalid arguments, including a negative
    interval, make argparse print an error and exit with status 2.
    """
    parser = argparse.ArgumentParser(description="SMC Monitor Tool (Experimental)")
    subparsers = parser.add_subparsers(dest='command', help='commands')

    # Speed command
    parser_speed = subparsers.add_parser('speed', aliases=['s'], help='View transfer rates')
    parser_speed.add_argument('-i', '--interval', type=_non_negative_int, default=0,
                              help="Interval in seconds to display transfer rates.")
    parser_speed.add_argument('-r', '--raw', action='store_true',
                              help="Display rates in B/s without converting units.")
    parser_speed.add_argument('-m', '--mode', choices=['smcr', 'smcd', 'smc'], default='smc',
                              help="Mode to check, either 'smc', 'smcr' or 'smcd', default is 'smc'")

    # Connection command
    parser_conn = subparsers.add_parser('connection', aliases=['c'], help='View connection counts')
    parser_conn.add_argument('-i', '--interval', type=_non_negative_int, default=0,
                             help="Interval in seconds to display connections.")
    parser_conn.add_argument('-m', '--mode', choices=['smc', 'smcr', 'smcd', 'fallback', 'all'], default='all',
                             help="Mode to check, either 'all', 'smc', 'smcr', 'smcd' or 'fallback', default is 'all'")

    # Memory command
    parser_memory = subparsers.add_parser('memory', aliases=['m'], help='View memory usages')
    parser_memory.add_argument('-i', '--interval', type=_non_negative_int, default=0,
                               help="Interval in seconds to display ringbuf usages.")
    parser_memory.add_argument('-r', '--raw', action='store_true',
                               help="Display memory usages in bytes without converting units.")
    parser_memory.add_argument('-m', '--mode', choices=['smcr', 'smcd', 'smc'], default='smc',
                               help="Mode to check, either 'smcr', 'smcd' or 'smc', default is 'smc'")

    # Base information, includes {Speed, Connection, Memory}
    parser_base = subparsers.add_parser('base', aliases=['b'], help='View transfer rates, connection counts, and memory usages')
    parser_base.add_argument('-i', '--interval', type=_non_negative_int, default=0, help="Interval in seconds.")
    parser_base.add_argument('-r', '--raw', action='store_true',
                             help="Display transfer rates/memory usages in bytes without converting units.")
    parser_base.add_argument('-m', '--mode', choices=['smcr', 'smcd', 'smc', 'fallback', 'all'], default='all', help="Mode to check.")

    args = parser.parse_args()

    # Sub-commands are optional for argparse, so a missing one must be
    # handled here: show the help instead of silently doing nothing.
    if args.command is None:
        parser.print_help(sys.stderr)
        sys.exit(1)

    if args.command in ['speed', 's']:
        display_transfer_rates(args.interval, args.raw, args.mode)
    elif args.command in ['connection', 'c']:
        display_conn_cnts(args.interval, args.mode)
    elif args.command in ['memory', 'm']:
        display_ringbuf_usages(args.interval, args.raw, args.mode)
    elif args.command in ['base', 'b']:
        display_base_info(args.interval, args.raw, args.mode)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
