#!/bin/bash

# Tool version reported by 'smc-ebpf version'.
readonly VERSION="1.8.3"
# bpftool command to run. policy_op replaces it with the system bpftool when
# one is on PATH, otherwise with ./bpftool from the current directory, so it
# is deliberately not readonly.
BPFT="bpftool"
# Name of the struct_ops program (checked by load, unregistered by unload) and
# the object file registered from this script's directory.
readonly BPFPROG_NAME="anolis_smc"
readonly BPFPROG_FILE="anolis_smc"
# BPF map names: per-IPv4-prefix policy and per-port policy.
readonly IP_POLICY="smc_strats_ip"
readonly PORT_POLICY="smc_strategies"

# Built-in port policy installed by 'policy init' (and 'policy load --init').
# Every entry is "<port> <mode>"; mode uses the same encoding that
# 'policy config --mode' produces: 0 = disable, 1 = auto, 2 = enable.
# NOTE(review): port 0 presumably acts as the catch-all entry - confirm
# against the BPF program.
default_config=(
    "0 1"
    "20 0" "21 0" "22 0" "23 0"
    "53 0" "68 0" "69 0"
    "80 0" "110 0" "111 0" "115 0" "123 0"
    "443 0" "2049 0"
)

# Set the globals holding the tuning fields that are packed, together with
# the mode byte, into every smc_strategies map value.
# The *_productivity values are comma-separated lists; endian() splits them
# on commas and serializes each element separately.
init_value() {
    rtt_threshold=48
    smc_productivity="3,14,20,0" tcp_productivity="14,18,20,0"
    max_credits=4096 min_credits=128 initial_credits=512
    max_pacing_burst=128
    pacing_delta=1
}

# Serialize integers as the space-separated hex byte list bpftool expects.
# $1: a number, or a comma-separated list of numbers (e.g. "3,14,20,0")
# $2: width in bytes of each number
# Every byte is printed as "0xHH " (trailing space, no newline).
# Example: endian 0x18 4
# output: 0x18 0x00 0x00 0x00
endian() {
    local nums=$1
    local byte_size=$2
    local num i byte

    # NOTE(review): this probe appears to print "000061" on little-endian
    # hosts and "030400" on big-endian ones, never "0000000", so the first
    # branch is dead and bytes are always emitted least-significant first.
    # ipv4_add_map/ipv4_delete_map rely on that ordering for the __be32
    # address from ipv4_aton, so only fix the probe together with them.
    if [ "$(echo -n "1" | od -to2 | head -n1 | awk '{print $2}')" = "0000000" ]; then
        for num in ${nums//,/ }; do
            for ((i = byte_size - 1; i >= 0; i--)); do
                byte=$(( (num >> (8 * i)) & 0xFF ))
                printf "0x%02X " "$byte"
            done
        done
    else
        for num in ${nums//,/ }; do
            for ((i = 0; i < byte_size; i++)); do
                byte=$(( (num >> (8 * i)) & 0xFF ))
                printf "0x%02X " "$byte"
            done
        done
    fi
}

# Convert a dotted-quad IPv4 address to the integer whose least-significant-
# byte-first serialization is the address in network order (i.e. the
# big-endian __be32 as read on a little-endian host).
# Each octet must be 1-3 decimal digits in 0-255. Leading zeros are read as
# decimal, never octal. Prints "error" (and returns 1) on malformed input.
# Example: ipv4_aton 192.168.122.33
# Output: 0x217AA8C0
ipv4_aton() {
	local ip=$1
	local -a octets
	local addr
	local i

	# exactly four dot-separated groups of 1-3 digits; this also rejects
	# empty groups and leading/trailing dots
	if ! [[ "$ip" =~ ^([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})$ ]]; then
		echo "error"
		return 1
	fi

	# 10# forces base 10 so "010" is 10 (not octal 8) and "08" is valid
	for i in 1 2 3 4; do
		octets[i-1]=$(( 10#${BASH_REMATCH[i]} ))
		if (( octets[i-1] > 255 )); then
			echo "error"
			return 1
		fi
	done

	addr=$(( (octets[3] << 24) | (octets[2] << 16) | \
		(octets[1] << 8) | octets[0] ))
	printf "0x%X\n" "$(( addr & 0xFFFFFFFF ))"
}

# Validate an IPv4 prefix length and print it in canonical decimal form.
# $1: prefix length 0-32 given as 1-3 decimal digits; empty means /32.
# Leading zeros are read as decimal, so "08" prints "8" rather than being
# passed on and later rejected as an invalid octal number by arithmetic.
# Prints "error" (and returns 1) when the value is not a valid prefix length.
ipv4_verify_mask() {
	local mask="$1"

	if [ -z "$mask" ]; then
		mask="32"	# if not set, use default '/32'
	fi
	if ! [[ "$mask" =~ ^[0-9]{1,3}$ ]]; then
		echo "error"
		return 1
	fi
	mask=$(( 10#$mask ))
	if (( mask > 32 )); then
		echo "error"
		return 1
	fi
	echo "$mask"
}

# Add or update the policy entry for an IPv4 prefix in the IP policy map.
# $1: dotted-quad IPv4 address
# $2: prefix length 0-32 (defaults to 32 when empty)
# $3: mode byte (policy_config maps disable/auto/enable to 0/1/2)
# Prints a diagnostic to stderr and exits 1 on invalid input.
# Example usage: ipv4_add_map "192.168.122.33" "24" "02"
ipv4_add_map() {
	local ipv4="$1"
	local ipv4_hex
	local ipv4_mask="$2"
	local mode="$3"

	ipv4_hex=$(ipv4_aton "$ipv4")
	if [[ ${ipv4_hex} == "error" ]]; then
		echo "IP format is incorrect. Only IPv4 addresses are accepted" >&2
		exit 1
	fi
	ipv4_mask=$(ipv4_verify_mask "$ipv4_mask")
	if [[ ${ipv4_mask} == "error" ]]; then
		echo "IP mask is incorrect. Please use correct IPv4 mask" >&2
		exit 1
	fi
	# an empty mode would make endian print nothing, leaving 'value' with
	# no bytes at all
	if [[ -z ${mode} ]]; then
		echo "IP policy mode is missing. Please set a mode" >&2
		exit 1
	fi
	# The key is the prefix length followed by the address, e.g. for
	# 192.168.122.33/24 with mode 2:
	# bpftool map update name smc_strats_ip \
	#	key 0x18 0x00 0x00 0x00 0xC0 0xA8 0x7A 0x21 value 0x02
	# The $(endian ...) results are intentionally unquoted: one arg per byte.
	"${BPFT}" map update name "${IP_POLICY}" \
		key $(endian "$ipv4_mask" 4) $(endian "$ipv4_hex" 4) \
		value $(endian "$mode" 1)
}

# Render the raw hex dump of the IPv4 policy map in readable form.
#
# Input (stdin), as printed by 'bpftool map dump name smc_strats_ip':
#   key: 10 00 00 00 c0 a7 03 00  value: 02
#   key: 18 00 00 00 c0 a8 7a 21  value: 02
#   key: 10 00 00 00 c0 a8 03 01  value: 02
#   Found 3 elements
#
# Output:
#   key:     192.167.3.0/16       value:   "pass"
#   key:     192.168.122.33/24    value:   "pass"
#   key:     192.168.3.1/16       value:   "pass"
#
# The 4 bytes after "key:" are the prefix length, least-significant byte
# first. The next 4 are the address octets in order. The value byte
# 00/01/02 is shown as "denied"/"auto"/"pass"; anything else is "unknown".
# Lines that do not contain both "key:" and "value:" are skipped.
#
# Example usage: bpftool map dump name smc_strats_ip | ipv4_dump_map
ipv4_dump_map() {
	local format="%-8s %-20s %-8s %-10s\n"
	local line idx ip value_str
	local mask_hex ip_hex value_hex
	local -a field

	while read -r line; do
		# only entry lines carry both markers; headers/footers are ignored
		[[ $line == *key:* && $line == *value:* ]] || continue

		read -ra field <<< "$line"
		mask_hex="" ip_hex="" value_hex=""
		for ((idx = 0; idx < ${#field[@]}; idx++)); do
			case ${field[idx]} in
			key:)
				mask_hex=${field[idx+4]}${field[idx+3]}${field[idx+2]}${field[idx+1]}
				ip_hex=${field[idx+5]}${field[idx+6]}${field[idx+7]}${field[idx+8]}
				;;
			value:)
				value_hex=${field[idx+1]}
				;;
			esac
		done

		ip=$(printf "%d.%d.%d.%d/%d\n" \
			0x${ip_hex:0:2} 0x${ip_hex:2:2} \
			0x${ip_hex:4:2} 0x${ip_hex:6:2} \
			0x$mask_hex)

		case "$value_hex" in
		"00") value_str="denied" ;;
		"01") value_str="auto" ;;
		"02") value_str="pass" ;;
		*) value_str="unknown" ;;
		esac

		printf "$format" "key:" "${ip}" "value:" "\"${value_str}\""
	done
}

# Remove the policy entry for an IPv4 prefix from the IP policy map.
# $1: dotted-quad IPv4 address
# $2: prefix length 0-32 (defaults to 32 when empty). The prefix length is
#     part of the key, so it must match the one the entry was added with.
# Any further arguments are ignored.
# Prints a diagnostic to stderr and exits 1 on invalid input.
# Example usage: ipv4_delete_map "192.168.122.33" "24"
ipv4_delete_map() {
	local ipv4="$1"
	local ipv4_hex
	local ipv4_mask="$2"

	ipv4_hex=$(ipv4_aton "$ipv4")
	if [[ ${ipv4_hex} == "error" ]]; then
		echo "IP format is incorrect. Only IPv4 addresses are accepted" >&2
		exit 1
	fi
	ipv4_mask=$(ipv4_verify_mask "$ipv4_mask")
	if [[ ${ipv4_mask} == "error" ]]; then
		echo "IP mask is incorrect. Please use correct IPv4 mask" >&2
		exit 1
	fi

	# Example:
	# bpftool map delete name smc_strats_ip key \
	#	0x18 0x00 0x00 0x00 0xc0 0xa8 0x7a 0x21
	# The $(endian ...) results are intentionally unquoted: one arg per byte.
	"${BPFT}" map delete name "${IP_POLICY}" \
		key $(endian "$ipv4_mask" 4) $(endian "$ipv4_hex" 4)
}

# Delete every entry listed in a raw dump of the IPv4 policy map.
#
# Input (stdin), as printed by 'bpftool map dump name smc_strats_ip':
#   key: 10 00 00 00 c0 a7 03 00  value: 02
#   key: 18 00 00 00 c0 a8 7a 21  value: 02
#   Found 2 elements
#
# For each line containing "key:", fields 2-9 (the 8 key bytes) are passed
# to 'bpftool map delete'. All of stdin is read before the first delete, so
# no entry is removed while the 'bpftool map dump' feeding this function may
# still be iterating over the map.
#
# Example usage: bpftool map dump name smc_strats_ip | ipv4_delete_map_raw
ipv4_delete_map_raw() {
	local -a lines keys=()
	local line key

	mapfile -t lines

	for line in "${lines[@]}"; do
		[[ $line == *key:* ]] || continue
		# e.g. "0x18 0x00 0x00 0x00 0xc0 0xa8 0x7a 0x21 "
		key=$(awk '{ for (i = 2; i <= 9; i++) printf "0x%s ", $i; print "" }' <<< "$line")
		keys+=("$key")
	done

	for key in "${keys[@]}"; do
		# Example:
		# bpftool map delete name smc_strats_ip key \
		#	0x18 0x00 0x00 0x00 0xc0 0xa8 0x7a 0x21
		# $key is intentionally unquoted: one argument per byte.
		"${BPFT}" map delete name "${IP_POLICY}" key $key
	done
}

# Print top-level usage to stdout and exit 1.
# $1: optional error message, printed to stderr before the usage text.
usage() {
    local errMsg=$1

    if [ -n "${errMsg}" ]; then
        echo "smc-ebpf: error: ${errMsg}" >&2
    fi

    echo " Usage: smc-ebpf COMMAND"
    echo "        smc-ebpf policy               policy for connecting smc"
    echo "        smc-ebpf version              display version information"
    echo " Examples: "
    echo "    smc-ebpf policy"

    exit 1
}

# Print usage for the 'policy' sub-commands to stdout and exit 1.
# $1: optional error message, printed to stderr before the usage text.
policy_usage() {
    local errMsg=$1

    if [ -n "${errMsg}" ]; then
        echo "smc-ebpf: error: ${errMsg}" >&2
    fi

    echo " Usage: smc-ebpf policy COMMAND [OPTIONS]"
    echo "        smc-ebpf policy load [OPTIONS]    load policy"
    echo "    --init                                load policy with the predefined config"
    echo "        smc-ebpf policy unload            unload policy"
    echo "        smc-ebpf policy init              init policy with default config"
    echo "        smc-ebpf policy clear             clear all policy config"
    echo "        smc-ebpf policy dump              display all policy config"
    echo "        smc-ebpf policy help              display this help"
    echo "        smc-ebpf policy config [OPTIONS]  config policy"
    echo "        smc-ebpf policy delete [OPTIONS]  delete policy"
    echo "    --ip [IPv4]                           target IPv4 address"
    echo "    --mask [0-32]                         IPv4 prefix length (default: 32)"
    echo "    --port [PORT]                         target port"
    echo "    --mode [auto|disable|enable]          target mode (config only)"
    echo "    --auto | --enable | --disable         shorthand for --mode (config only)"
    echo " Examples: "
    echo "    smc-ebpf policy load"
    echo "    #disable port 80 to use smc"
    echo "    smc-ebpf policy config --port 80 --mode disable"
    echo "    #delete ip 192.168.0.0/24 policy"
    echo "    smc-ebpf policy delete --ip 192.168.0.0 --mask 24"

    exit 1
}

# 'smc-ebpf policy help': show the policy usage text. policy_usage exits, and
# any arguments (including the command word) are ignored.
policy_help() { policy_usage; }

# 'smc-ebpf policy config': add or update one policy entry.
# Options:
#   --ip IPv4 [--mask LEN]  or  --port PORT   (exactly one of --ip/--port)
#   --mode auto|enable|disable, or --auto/--enable/--disable   (required)
# Modes are stored as 1 (auto), 2 (enable), 0 (disable).
# policy_op passes the command word ("config") as $1; it is dropped here.
# Invalid input is reported through policy_usage, which exits.
policy_config() {
    local opts="port:,ip:,mask:,mode:,auto,enable,disable"
    local port ip mask mode parsed_args

    if [ "${1-}" = "config" ]; then
        shift
    fi

    # parse config; -o '' declares "no short options" explicitly, and the
    # status check catches getopt errors instead of silently continuing
    if ! parsed_args=$(getopt -o '' --long "${opts}" -n smc-ebpf -- "$@"); then
        policy_usage "invalid OPTIONS for config"
    fi
    eval set -- "$parsed_args"
    while [ $# -gt 0 ]; do
        case "$1" in
        --port)
            port=$2
            shift 2
            ;;
        --ip)
            ip=$2
            shift 2
            ;;
        --mask)
            mask=$2
            shift 2
            ;;
        --mode)
            case "$2" in
            auto) mode=1 ;;
            enable) mode=2 ;;
            disable) mode=0 ;;
            *) policy_usage "unrecognized mode: $2" ;;
            esac
            shift 2
            ;;
        --auto)
            mode=1
            shift
            ;;
        --enable)
            mode=2
            shift
            ;;
        --disable)
            mode=0
            shift
            ;;
        --)
            shift
            break
            ;;
        *)
            policy_usage "unrecognized OPTION: $1"
            ;;
        esac
    done

    if [ -n "$ip" ] && [ -n "$port" ]; then
        policy_usage "ip and port cannot be set at the same time."
    fi
    if [ -z "$ip" ] && [ -z "$port" ]; then
        policy_usage "either --ip or --port is required."
    fi
    # without a mode the packed value would be missing its first byte
    if [ -z "$mode" ]; then
        policy_usage "--mode is required."
    fi

    # update config
    if [ -n "$ip" ]; then
        ipv4_add_map "$ip" "$mask" "$mode"
    fi

    if [ -n "$port" ]; then
        # decimal 0-65535 only; 10# keeps a leading zero from meaning octal
        if ! [[ "$port" =~ ^[0-9]{1,5}$ ]] || (( 10#$port > 65535 )); then
            policy_usage "invalid port: $port"
        fi
        port=$(( 10#$port ))

        init_value
        # Value bytes: mode (1), literal 0x00 (presumably padding),
        # rtt_threshold (2), each smc/tcp productivity element (2 each),
        # four 4-byte credit/burst fields, four 0x00 bytes, pacing_delta (8).
        # The $(endian ...) results are intentionally unquoted.
        "${BPFT}" map update name "${PORT_POLICY}" \
            key $(endian "$port" 2) \
            value \
            $(endian "$mode" 1) 0x00 \
            $(endian "$rtt_threshold" 2) \
            $(endian "$smc_productivity" 2) \
            $(endian "$tcp_productivity" 2) \
            $(endian "$max_credits" 4) \
            $(endian "$min_credits" 4) \
            $(endian "$initial_credits" 4) \
            $(endian "$max_pacing_burst" 4) \
            0x00 0x00 0x00 0x00 \
            $(endian "$pacing_delta" 8)
    fi
}

# 'smc-ebpf policy delete': remove one policy entry.
# Options: --ip IPv4 [--mask LEN]  or  --port PORT   (exactly one of them)
# policy_op passes the command word ("delete") as $1; it is dropped here.
# Invalid input is reported through policy_usage, which exits.
policy_delete() {
    local opts="port:,ip:,mask:"
    local port ip mask parsed_args

    if [ "${1-}" = "delete" ]; then
        shift
    fi

    # parse config; -o '' declares "no short options" explicitly, and the
    # status check catches getopt errors instead of silently continuing
    if ! parsed_args=$(getopt -o '' --long "${opts}" -n smc-ebpf -- "$@"); then
        policy_usage "invalid OPTIONS for delete"
    fi
    eval set -- "$parsed_args"
    while [ $# -gt 0 ]; do
        case "$1" in
        --port)
            port=$2
            shift 2
            ;;
        --ip)
            ip=$2
            shift 2
            ;;
        --mask)
            mask=$2
            shift 2
            ;;
        --)
            shift
            break
            ;;
        *)
            policy_usage "unrecognized OPTION: $1"
            ;;
        esac
    done

    if [ -n "$ip" ] && [ -n "$port" ]; then
        policy_usage "IP and port cannot be set at the same time."
    fi
    if [ -z "$ip" ] && [ -z "$port" ]; then
        policy_usage "either --ip or --port is required."
    fi

    # update config
    if [ -n "$ip" ]; then
        ipv4_delete_map "$ip" "$mask"
    fi

    if [ -n "$port" ]; then
        # decimal 0-65535 only; 10# keeps a leading zero from meaning octal
        if ! [[ "$port" =~ ^[0-9]{1,5}$ ]] || (( 10#$port > 65535 )); then
            policy_usage "invalid port: $port"
        fi
        # the $(endian ...) result is intentionally unquoted: one arg per byte
        "${BPFT}" map delete name "${PORT_POLICY}" \
            key $(endian "$(( 10#$port ))" 2)
    fi
}

# Remove every entry from both policy maps.
# Port map: the numeric keys are scraped from bpftool's formatted dump
# (lines containing ' "key": <n>'), collected first, then deleted one by one.
# IPv4 map: the dump is piped to ipv4_delete_map_raw, which deletes each key.
policy_clear() {
    local port_key
    local -a port_keys

    mapfile -t port_keys < <(${BPFT} map dump name ${PORT_POLICY} \
        | grep -oP '(?<= "key": )\d+')

    for port_key in "${port_keys[@]}"; do
        ${BPFT} map delete name ${PORT_POLICY} key $(endian $port_key 2)
    done

    ${BPFT} map dump name ${IP_POLICY} | ipv4_delete_map_raw
}

# 'smc-ebpf policy unload': empty both policy maps via policy_clear, then
# unregister the struct_ops program named by BPFPROG_NAME.
policy_unload() {
    local prog_name=${BPFPROG_NAME}

    policy_clear
    ${BPFT} struct_ops unregister name ${prog_name}
}

# Install the built-in port policy (default_config) into the port map.
# Each "<port> <mode>" entry is stored with the tuning fields from
# init_value. Any arguments are ignored. IPv4 policy entries are not touched.
policy_init() {
    local entry port mode

    # the tuning fields are the same for every entry; load them once
    init_value
    for entry in "${default_config[@]}"; do
        read -r port mode <<< "${entry}"
        # Value bytes: mode (1), literal 0x00 (presumably padding),
        # rtt_threshold (2), each smc/tcp productivity element (2 each),
        # four 4-byte credit/burst fields, four 0x00 bytes, pacing_delta (8).
        # The $(endian ...) results are intentionally unquoted.
        "${BPFT}" map update name "${PORT_POLICY}" \
            key $(endian "$port" 2) \
            value \
            $(endian "$mode" 1) 0x00 \
            $(endian "$rtt_threshold" 2) \
            $(endian "$smc_productivity" 2) \
            $(endian "$tcp_productivity" 2) \
            $(endian "$max_credits" 4) \
            $(endian "$min_credits" 4) \
            $(endian "$initial_credits" 4) \
            $(endian "$max_pacing_burst" 4) \
            0x00 0x00 0x00 0x00 \
            $(endian "$pacing_delta" 8)
    done
    # ip strategies won't be involved.
}

# 'smc-ebpf policy dump': print the IPv4 policy map decoded by
# ipv4_dump_map, followed by bpftool's own dump of the port policy map.
policy_dump() {
    local ip_map=${IP_POLICY} port_map=${PORT_POLICY}

    ${BPFT} map dump name ${ip_map} | ipv4_dump_map   # IPv4 policy, decoded
    ${BPFT} map dump name ${port_map}                  # port policy, as-is
}

# 'smc-ebpf policy load [--init]': register the struct_ops program from this
# script's directory unless one named BPFPROG_NAME is already registered.
# With --init, also run policy_init, even when the program was already
# loaded. Exits 1 if registration fails.
# policy_op passes the command word ("load") as $1; it is dropped here.
policy_load() {
    local init=0
    local opts="init"
    local parsed_args dir

    if [ "${1-}" = "load" ]; then
        shift
    fi

    # parse config; -o '' declares "no short options" explicitly, and the
    # status check catches getopt errors instead of silently continuing
    if ! parsed_args=$(getopt -o '' --long "${opts}" -n smc-ebpf -- "$@"); then
        policy_usage "invalid OPTIONS for load"
    fi
    eval set -- "$parsed_args"
    while [ $# -gt 0 ]; do
        case "$1" in
        --init)
            init=1
            shift
            ;;
        --)
            shift
            break
            ;;
        *)
            policy_usage "unrecognized OPTION: $1"
            ;;
        esac
    done

    # register only if not loaded yet; -w keeps a longer name that merely
    # contains ours (e.g. anolis_smc_v2) from counting as a match
    if ! "${BPFT}" struct_ops | grep -qw -- "${BPFPROG_NAME}"; then
        dir="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
        if ! "${BPFT}" struct_ops register "${dir}/${BPFPROG_FILE}"; then
            echo "smc-ebpf: error: failed to register ${dir}/${BPFPROG_FILE}" >&2
            exit 1
        fi
    fi

    if [ "$init" -eq 1 ]; then
        policy_init
    fi
}

# Dispatch 'smc-ebpf policy COMMAND [ARGS...]' to the policy_COMMAND function.
# Reads the global $op set by main. The handler receives the command word as
# $1, followed by the remaining arguments unchanged.
# Every command except 'help' requires root and a bpftool. BPFT is set to
# the bpftool on PATH if there is one, otherwise to ./bpftool.
policy_op() {
    case "${op}" in
    load | unload | stop | init | clear | config | delete | help | dump) ;;
    "")
        policy_usage
        ;;
    *)
        policy_usage "unrecognized COMMAND: ${op}"
        ;;
    esac

    # a listed command may have no handler (there is no policy_stop);
    # report that clearly instead of failing with "command not found"
    if ! declare -F "policy_${op}" >/dev/null; then
        policy_usage "COMMAND not implemented: ${op}"
    fi

    # 'help' only prints usage, so it needs neither root nor bpftool
    if [ "${op}" != "help" ]; then
        # required super privileged
        if [ "$(id -u)" -ne 0 ]; then
            policy_usage "required root privileged"
        fi

        # adjust bpftool path
        if command -v bpftool >/dev/null 2>&1; then
            BPFT=$(command -v bpftool)
        elif [ -x "./bpftool" ]; then
            BPFT="./bpftool"
        else
            policy_usage "bpftool is not installed."
        fi
    fi

    # call the handler directly (no eval) so user arguments are passed
    # verbatim instead of being re-parsed as shell code
    "policy_${op}" "${op}" "$@"
}

# main: smc-ebpf OBJECT [COMMAND [ARGS...]]
obj="$1"
op="$2"	# global, read by policy_op
# drop OBJECT and COMMAND (whichever of them were supplied)
if [ $# -ge 2 ]; then
    shift 2
else
    set --
fi

case $obj in
    "")
        usage
        ;;
    policy)
        policy_op "$@"
        ;;
    version)
        echo "smc-ebpf utility, smc-tools-$VERSION"
        ;;
    *)
        usage "unrecognized: ${obj}"
        ;;
esac
