from synapse.nvparams import *
from synapse.switchboard import *


# Module-level protocol state shared by the RPC handlers below and pf_timer().

is_bridge = False # set by pf_setup(): True when the NV device type is 'Bridge'
last_seq = 0 # sequence number of the most recently accepted ping/get_data

# get_data / get_data1 request state
is_get_data = False # True while a get_data request is being serviced
saw_data_ack = True # False while a reply to the current get_data is awaited
data_func_name = '' # function requested by the current get_data
get_data_arg_count = 0 # 0 for get_data, 1 for get_data1
get_data_arg1 = None # argument forwarded by get_data1
get_data_addr = '' # target address; a bridge appends the requester's address
get_data_resp = '' # most recent response for the current request
get_data_more = False # 'more' flag that accompanied that response
sent_get_data_resp = True # False while the response still has to be sent
get_data_wait_cntr = 0 # pf_timer() ticks since the request was last sent
get_data_sent_cnt = 0 # number of times pf_timer() has resent the request
GET_DATA_WAIT_CNTR_MAX = 3 # resend once more than this many ticks pass without a reply
MAX_GET_DATAS = 255 # stop servicing a request after more than this many resends

# Ping response state
ping_resp_addr = '' # node the ping was heard from; responses go back via it (should be small string)
responding_addrs = '' # response payload still to be sent (should be a medium string)
responding_addrs_index = 0 # length of the part of responding_addrs carried by the outstanding route_ping
unackd_ping_ctr = 0 # route_ping attempts without an rpA (bridge: failed tell_ping attempts)
waiting_for_ack = False # True once a route_ping is queued, until its rpA arrives
my_rp_seq = 0 # this node's route_ping sequence number, advanced on each matching rpA
wait_cntr = 0 # ticks since the last route_ping attempt (used when ttl > 1)
noise_cntr = 0 # ticks spent waiting between route_ping attempts
initial_delay_cntr = 0 # ticks waited before the first route_ping
ttl = 1 # hop limit for route_ping, raised up to MAX_TLL
sent_addresses = True # bridge: True once another node's addresses have been forwarded
ping_data_func = None # optional function whose result is added to ping responses
bridge_no_response_cntr = 0 # number of times the bridge has re-multicast the ping

MAX_ROUTE_PINGS = 255 # unacknowledged route_pings before the ping is abandoned
MAX_PINGS = 5 # Needs to be bigger than INITIAL_DELAY; limit on bridge re-pings
MAX_TLL = 5 # maximum route_ping ttl (presumably meant as MAX_TTL)
ESCALATION_THRESH = 20 # unacknowledged route_pings before ttl is raised
WAIT_CNTR_MAX = 5 # ticks between route_ping attempts when ttl > 1
MAX_PKTS = 4 # only send a route_ping while getStat(9) is below this
INITIAL_DELAY = 3 # ticks a non-bridge waits before its first route_ping

MCAST_GROUP = 0x100 # multicast group used for this protocol's traffic

PF_VERSION = 12 # protocol version number (not referenced in this section)


# Setters for trying different protocol settings on the fly.
# Disabled by default; change False to True to define them.
if False:
    def change_INITIAL_DELAY(delay):
        """Set how many timer ticks a node waits before its first route_ping."""
        global INITIAL_DELAY
        INITIAL_DELAY = delay

    def change_WAIT_CNTR_MAX(new_thresh):
        """Set the number of ticks between route_ping attempts when ttl > 1."""
        global WAIT_CNTR_MAX
        WAIT_CNTR_MAX = new_thresh

    def change_ESCALATION_THRESH(new_thresh):
        """Set how many unacknowledged route_pings trigger a ttl increase."""
        global ESCALATION_THRESH
        ESCALATION_THRESH = new_thresh

    def change_MAX_TLL(value):
        """Set the highest ttl a route_ping may escalate to."""
        global MAX_TLL
        MAX_TLL = value

    def change_MAX_PKTS(new_thresh):
        """Set the getStat(9) limit below which a route_ping may be sent."""
        global MAX_PKTS
        MAX_PKTS = new_thresh

    def change_GET_DATA_WAIT_CNTR_MAX(value):
        """Set how many ticks to wait before resending a get_data request."""
        global GET_DATA_WAIT_CNTR_MAX
        GET_DATA_WAIT_CNTR_MAX = value

    def clear_addrs():
        """Discard any pending ping response payload."""
        global responding_addrs
        responding_addrs = ''

# Functions to help debug the state of this node.
# Disabled by default; change False to True to enable. Each getter multicasts
# the current value to group 1 with ttl 5 as a 'printRtnVal' call.
if False:
    unable_to_rpa = 0
    unable_to_append = 0
    unable_to_tell = 0

    def get_unable_to_rpa():
        mcastRpc(1,5,"printRtnVal", unable_to_rpa)

    def get_unable_to_append():
        mcastRpc(1,5,"printRtnVal", unable_to_append)

    def get_unable_to_tell():
        mcastRpc(1,5,"printRtnVal", unable_to_tell)

    def get_last_seq():
        mcastRpc(1,5,"printRtnVal", last_seq)

    def get_ping_resp_addr():
        mcastRpc(1,5,"printRtnVal", ping_resp_addr)

    def get_responding_addrs():
        # Length first, since the payload may contain binary data
        mcastRpc(1,5,"printRtnVal", len(responding_addrs))
        mcastRpc(1,5,"printRtnVal", responding_addrs)

    def get_responding_addrs_index():
        mcastRpc(1,5,"printRtnVal", responding_addrs_index)

    def get_unackd_ping_ctr():
        mcastRpc(1,5,"printRtnVal", unackd_ping_ctr)

    def get_waiting_for_ack():
        mcastRpc(1,5,"printRtnVal", waiting_for_ack)

    def get_noise_cntr():
        mcastRpc(1,5,"printRtnVal", noise_cntr)

    def get_get_data_arg_count():
        # Must not be named get_data_arg_count: a def with that name would
        # replace the get_data_arg_count state variable (initially 0), and
        # get_data_common()/clear_get_data() would in turn overwrite this
        # function with an int, leaving the getter uncallable.
        mcastRpc(1,5,"printRtnVal", get_data_arg_count)

def get_data(seq, addr, func_name):
    """Handle a get_data request for a function that takes no arguments.

    All of the work happens in get_data_common(); this entry point only
    fixes the argument count at zero and supplies no argument.
    """
    no_args = 0
    get_data_common(seq, addr, func_name, no_args, None)

def get_data1(seq, addr, func_name, arg1):
    """Handle a get_data request for a function that takes one argument.

    Identical to get_data() except that arg1 is carried along and passed to
    func_name on the target node; the work happens in get_data_common().
    """
    single_arg = 1
    get_data_common(seq, addr, func_name, single_arg, arg1)

def get_data_common(seq, addr, func_name, arg_count, arg1):
    """Shared handler behind get_data (arg_count 0) and get_data1 (arg_count 1).

    Always sets is_get_data. A request is accepted when its masked sequence
    number is newer than last_seq, or is 0. Accepting it cancels any ping in
    progress, resets the per-request state and then:
      * if this node is the target (addr == localAddr()), calls func_name
        (with arg1 when arg_count is 1) and sends the result -- directly to
        the requester via 'get_data_result' on a bridge (returning at once
        if that rpc is queued), or by multicasting 'get_data_ack' otherwise;
      * records the target address (a bridge appends the requester's
        address), the function and its arguments so pf_timer() can resend;
      * on a bridge, immediately re-multicasts the request.
    A request whose sequence number is not newer only clears
    sent_get_data_resp, so pf_timer() resends the stored response.

    Args:
        seq: request sequence number; only the low 15 bits are used.
        addr: address of the node that should run func_name.
        func_name: function to call on the target node.
        arg_count: number of arguments for func_name (0 or 1).
        arg1: the argument used when arg_count is 1.
    """
    global is_get_data, last_seq, get_data_addr, data_func_name, saw_data_ack, sent_get_data_resp, get_data_resp
    global get_data_wait_cntr, get_data_sent_cnt, get_data_arg1, get_data_arg_count, get_data_more

    is_get_data = True
    seq &= 0x7fff # only the low 15 bits of the sequence number are used
    if seq > last_seq or seq == 0:
        clear_ping() # Can't do a ping and get_data at the same time
        if is_bridge and seq == 0:
            # Bridge: give each seq 0 request a fresh, decreasing negative last_seq
            if last_seq > -1:
                last_seq = -1
            else:
                last_seq -= 1
        else:
            last_seq = seq

        # New request: nothing received or sent yet
        saw_data_ack = False
        get_data_resp = ''
        get_data_more = False
        sent_get_data_resp = False
        get_data_arg_count = arg_count
        get_data_arg1 = arg1
        if addr == localAddr():
            # We are the target: run the function and answer right away
            if arg_count == 0:
                get_data_resp = func_name()
            elif arg_count == 1:
                get_data_resp = func_name(arg1)
            saw_data_ack = True
            if is_bridge:
                # Reply straight to whoever sent us the request
                sent_get_data_resp = rpc(rpcSourceAddr(), 'get_data_result', seq, addr, get_data_resp, get_data_more)
                if sent_get_data_resp:
                    return
            else:
                sent_get_data_resp = mcastRpc(MCAST_GROUP, 1, 'get_data_ack', seq, addr, get_data_resp, get_data_more)

        # Remember the request so pf_timer() can resend the request/response
        get_data_addr = addr
        if is_bridge:
            # Bridge saves requester address in get_data_addr
            get_data_addr += rpcSourceAddr()
        data_func_name = func_name
        # NOTE(review): this overwrites the negative last_seq chosen above for a
        # bridge seq == 0 request, so last_seq always ends up equal to seq here --
        # confirm whether that is intended.
        last_seq = seq
        get_data_wait_cntr = 0
        get_data_sent_cnt = 0

        if is_bridge:
            # Go ahead and resend right away
            if arg_count == 0:
                mcastRpc(MCAST_GROUP, 1, 'get_data', seq, addr, func_name)
            elif arg_count == 1:
                mcastRpc(MCAST_GROUP, 1, 'get_data1', seq, addr, func_name, arg1)
        #else:
            # Wait in case the node we are asking for is in range of the node so he can respond
    else:
        sent_get_data_resp = False # resend ACK

def get_data_ack(seq, addr, resp, more):
    """Handle a get_data response travelling back toward the requester.

    Responses with seq older than last_seq are ignored. A response with a
    newer seq, or the first one for the current seq, is relayed: a bridge
    sends 'get_data_result' to the requester address it stored in
    get_data_addr[3:6]; other nodes re-multicast 'get_data_ack'. The response
    is then recorded in get_data_resp so pf_timer() can resend it. A bridge
    whose stored 'more' flag is set asks for more data via
    check_get_data_more().

    Args:
        seq: sequence number of the request being answered.
        addr: address of the node that produced the response.
        resp: the response data.
        more: flag sent with the response; relayed and stored in get_data_more.
    """
    global saw_data_ack, get_data_addr, get_data_resp, sent_get_data_resp, get_data_more

    if seq >= last_seq:
        if seq > last_seq or not saw_data_ack:
            get_data_more = more
            #mcastRpc(1, 1, "printRtnVal", get_data_more)
            #writePin(1, more)
            if is_bridge:
                # Bridge saves address in get_data_addr
                # NOTE(review): assumes get_data_addr still holds target + requester
                # address; clear_get_data() sets it to None -- confirm a late ACK
                # cannot reach this slice after a ping has started.
                sent_get_data_resp = rpc(get_data_addr[3:6], 'get_data_result', seq, addr, resp, get_data_more)
            else:
                sent_get_data_resp = mcastRpc(MCAST_GROUP, 1, 'get_data_ack', seq, addr, resp, get_data_more)
        saw_data_ack = True
        if not is_bridge:
            # Non-bridge nodes remember the responder as the address to resend for
            get_data_addr = addr
        elif get_data_more:
            check_get_data_more()
        get_data_resp = resp

def check_get_data_more():
    """Request the next piece of data when the last reply had 'more' set.

    Acts only once the previous response has been delivered
    (sent_get_data_resp). Advances last_seq, re-issues the original
    get_data/get_data1 request under the new sequence number and, if the
    multicast was queued, clears saw_data_ack so the reply is treated as
    outstanding.
    """
    global last_seq, saw_data_ack

    if not get_data_more or not sent_get_data_resp:
        return

    last_seq += 1
    queued = False
    if get_data_arg_count == 1:
        queued = mcastRpc(MCAST_GROUP, 1, 'get_data1', last_seq, get_data_addr[0:3], data_func_name, get_data_arg1)
    elif get_data_arg_count == 0:
        queued = mcastRpc(MCAST_GROUP, 1, 'get_data', last_seq, get_data_addr[0:3], data_func_name)

    if queued:
        saw_data_ack = False
def clear_ping():
    """Abandon any ping response in progress; does nothing on a bridge.

    Empties the response buffer and its index, forgets the return address
    and data function, and sets sent_addresses.
    """
    global responding_addrs, responding_addrs_index, ping_resp_addr, ping_data_func, sent_addresses

    if is_bridge:
        return

    #print "clear_ping"
    ping_data_func = None
    ping_resp_addr = None
    responding_addrs = ''
    responding_addrs_index = 0
    sent_addresses = True

def clear_get_data():
    """Forget the get_data request being serviced (do_ping() calls this).

    Only the request description is reset; the response/ACK flags are
    reinitialized by the next get_data_common() call.
    """
    global is_get_data, get_data_addr, data_func_name, get_data_arg_count, get_data_arg1

    get_data_arg1 = None
    get_data_arg_count = 0
    data_func_name = None
    get_data_addr = None
    is_get_data = False

#@setHook(HOOK_STARTUP)
def pf_setup():
    """Bring this node's NV parameters in line with what the protocol needs.

    Presumably run at startup (the HOOK_STARTUP decorator above is commented
    out). It does the following:
      * ensures carrier sense is on and collision detect/avoidance are off;
      * ensures MCAST_GROUP is in the group interest mask;
      * configures the node according to NV_DEVICE_TYPE_ID:
          - 'Bridge': route-age timeout 0, MCAST_GROUP removed from the
            forwarding mask, is_bridge = True;
          - any other type: MCAST_GROUP added to the forwarding mask and
            is_bridge = False; packet-serial output is cross-connected to
            DS_NULL unless the type is 'HalfBridge'.

    Returns:
        True if any NV parameter was changed. The reboot below is commented
        out, so presumably the caller reboots when this returns True.
    """
    global is_bridge
    needs_reboot = False

    if getInfo(4) == 0:
        # If this is a debug build send out error messages
        crossConnect(DS_ERROR, DS_TRANSPARENT)
        mcastSerial(1, loadNvParam(NV_MESH_MAX_HOPLIMIT_ID))

    # Radio access settings (always enforced)
    if True:
        if loadNvParam(NV_CARRIER_SENSE_ID) != True:
            saveNvParam(NV_CARRIER_SENSE_ID, True)
            needs_reboot = True
        if loadNvParam(NV_COLLISION_DETECT_ID) != False:
            saveNvParam(NV_COLLISION_DETECT_ID, False)
            needs_reboot = True
        if loadNvParam(NV_COLLISION_AVOIDANCE_ID) != False:
            saveNvParam(NV_COLLISION_AVOIDANCE_ID, False)
            needs_reboot = True

    # Every node must listen to the protocol's multicast group
    curr_forward_mask = loadNvParam(NV_GROUP_FORWARDING_MASK_ID)
    curr_interest_mask = loadNvParam(NV_GROUP_INTEREST_MASK_ID)
    if not (curr_interest_mask & MCAST_GROUP):
        saveNvParam(NV_GROUP_INTEREST_MASK_ID, curr_interest_mask|MCAST_GROUP)
        needs_reboot = True

    device_type = loadNvParam(NV_DEVICE_TYPE_ID)
    if not device_type:
        device_type = 'Unknown'

    if device_type == 'Bridge':
        if loadNvParam(NV_MESH_ROUTE_AGE_MAX_TIMEOUT_ID) != 0:
            saveNvParam(NV_MESH_ROUTE_AGE_MAX_TIMEOUT_ID, 0)
            needs_reboot = True
        if curr_forward_mask & MCAST_GROUP:
            # The bit is known to be set, so XOR clears it
            saveNvParam(NV_GROUP_FORWARDING_MASK_ID, curr_forward_mask^MCAST_GROUP)
            needs_reboot = True
        is_bridge = True
        # DEBUG !!! (SNAPSTICK 100):
        #setPinDir(0, True)
        #setPinDir(1, True)
    else:
        if not (curr_forward_mask & MCAST_GROUP):
            saveNvParam(NV_GROUP_FORWARDING_MASK_ID, curr_forward_mask|MCAST_GROUP)
            needs_reboot = True
        if device_type != 'HalfBridge':
            # Turn off sending multi-casts over PS
            crossConnect(DS_PACKET_SERIAL, DS_NULL)
        is_bridge = False

    #if needs_reboot:
    #    reboot()

    #stdinMode(1, False)

    return needs_reboot

def proxy_ping(seq, data_func):
    """Start a ping on this node; unlike ping(), bridges act on it too."""
    return do_ping(seq, data_func)

def ping(seq, data_func):
    """Ping request as heard over the mesh; a bridge ignores it."""
    if is_bridge:
        return
    do_ping(seq, data_func)

def do_ping(seq, data_func):
    """Global ping request.

    Starts processing a ping the first time its seq is seen; seq 0 is always
    treated as new. Cancels any get_data in progress and remembers the node
    the ping was heard from (rpcSourceAddr()) as the return path. It then
    resets the route-ping state and seeds responding_addrs with this node's
    address. When data_func is given, the seed is this node's address, a
    length byte and data_func()'s result.

    A bridge immediately re-multicasts the ping. Other nodes leave sending
    their response to pf_timer().

    Args:
        seq: ping sequence number.
        data_func: optional function whose result is appended to this node's
            response; a falsy value means a plain ping.
    """
    # NOTE(review): outstanding_ping is declared global but never used here.
    global ping_resp_addr, last_seq, responding_addrs, outstanding_ping, responding_addrs_index
    global unackd_ping_ctr, waiting_for_ack, my_rp_seq
    global wait_cntr, noise_cntr, ttl, initial_delay_cntr
    global sent_addresses, ping_data_func, bridge_no_response_cntr

    # Check if we have seen this ping before
    if seq != last_seq or seq == 0:
        clear_get_data() # Can't do a ping and get_data at the same time
        if is_bridge and seq == 0:
            # Bridge: map each seq 0 request to a fresh, decreasing negative
            # sequence number, which is then used when re-multicasting below
            if last_seq > -1:
                last_seq = -1
            else:
                last_seq -= 1
        else:
            last_seq = seq

        # Send our response to the ping back to who we originally heard it from
        ping_resp_addr = rpcSourceAddr()

        # Setup to send our ping response when timer fires
        waiting_for_ack = False
        unackd_ping_ctr = 0
        my_rp_seq = 0
        wait_cntr = 0
        noise_cntr = 0
        ttl = 1
        getStat(9) # Reset radio counter
        sent_addresses = False
        ping_data_func = data_func
        if data_func:
            # Call function and add onto reponse
            tmp = data_func()
            # Work around for a SNAPpy string truncation problem:
            #     Taking str(already_a_string) terminates the existing string on the first NUL byte
            # So, for "binary data", don't convert it to a string if it *already is one*
            if len(tmp) is None: # I verified this works and does not set errno() - KRB
                responding_addrs = str(tmp) # Not a string, so OK to convert
            else:
                responding_addrs = tmp # Already a string, so use it as-is
            # Prepend our address and the length of the data onto the actual data
            responding_addrs = localAddr() + chr(len(responding_addrs)) + responding_addrs
        else:
            responding_addrs = localAddr() # Add ourselves to the list of addresses responding to the ping
        responding_addrs_index = len(responding_addrs)

        # Rebroadcast ping request
        if is_bridge:
            mcastRpc(MCAST_GROUP, 1, 'ping', last_seq, ping_data_func)
            initial_delay_cntr = INITIAL_DELAY
            bridge_no_response_cntr = 0
        else:
            initial_delay_cntr = 0
            #mcastRpc(MCAST_GROUP, 1, 'printRtnVal', seq)

        # DEBUG
        if False:
            global unable_to_rpa, unable_to_append, unable_to_tell
            unable_to_rpa = 0
            unable_to_append = 0
            unable_to_tell = 0

def route_ping(seq, next_hop, addrs, rp_seq, orig_ttl, data_func):
    """Called when a node needs a ping response routed.

    If this node is next_hop for the current ping and addrs fit, it ACKs the
    sender with an 'rpA' multicast (using the sender's ttl) and appends addrs
    to responding_addrs. A bridge then tries to hand everything to the
    requester at once with 'tell_ping'. Other nodes leave forwarding to
    pf_timer().

    A non-bridge node that overhears a route_ping not meant for it, or for a
    different seq, passes it to do_ping() in case it is a ping it has not
    seen yet.

    Args:
        seq: ping sequence number.
        next_hop: address of the node asked to forward the response.
        addrs: response payload to forward.
        rp_seq: sender's route-ping sequence number, echoed in the ACK.
        orig_ttl: ttl the sender used; reused for the ACK.
        data_func: the ping's data function, passed along to rpA/do_ping.
    """
    # NOTE(review): waiting_for_ack is declared global but not used here.
    global waiting_for_ack, responding_addrs, sent_addresses

    if next_hop == localAddr() and seq == last_seq:
        # Someone needs us to forward their ping response
        # 62 is presumably the maximum payload that fits in one packet -- confirm
        if (len(responding_addrs) + len(addrs)) <= 62: # Check to make sure there is room
            queued = mcastRpc(MCAST_GROUP, orig_ttl, 'rpA', seq, rpcSourceAddr(), rp_seq, data_func)

            # Only add addrs if we can send an ACK, otherwise it will just get added again
            if queued:
                responding_addrs += addrs
                if is_bridge:
                    if rpc(ping_resp_addr, 'tell_ping', last_seq, responding_addrs, True if ping_data_func else False):
                        responding_addrs = ''
                        sent_addresses = True
                        #global unable_to_tell
                        #unable_to_tell += 1
                        #writePin(0, True)
                    #else:
                        #writePin(0, False)
                ##else:
                    ##pass
                    ##We don't forward the ping on until the timer fires
            #else:
                #global unable_to_rpa
                #unable_to_rpa += 1
        #else:
            #global unable_to_append
            #unable_to_append += 1
    elif not is_bridge:
        do_ping(seq, data_func) # Check and make sure we have not seen this ping before

def rpA(seq, addr, rp_seq, data_func):
    """Route Ping ACK: the next hop has accepted our route_ping.

    When addressed to this node for the current ping and carrying the
    outstanding route-ping sequence number, drops the acknowledged prefix of
    responding_addrs, clears the wait state and advances my_rp_seq. Any
    other rpA overheard by a non-bridge node is passed to do_ping() in case
    it belongs to a ping this node has not seen yet.
    """
    global waiting_for_ack, responding_addrs, unackd_ping_ctr, my_rp_seq

    for_this_ping = addr == localAddr() and seq == last_seq
    if not for_this_ping:
        if not is_bridge:
            do_ping(seq, data_func) # Check and make sure we have not seen this ping before
        return

    if rp_seq != my_rp_seq:
        return # ACK for some other route_ping; ignore it

    # Keep only the addresses appended after the acknowledged route_ping went out
    responding_addrs = responding_addrs[responding_addrs_index:]
    waiting_for_ack = False
    unackd_ping_ctr = 0
    my_rp_seq += 1

#@setHook(HOOK_100MS)
def pf_timer(ms):
    """Periodic driver for get_data retries and ping-response routing.

    Presumably run every 100 ms (the HOOK_100MS decorator above is commented
    out); ms is not used. Each call handles at most one of these cases, in
    this priority order:

    1. get_data in progress:
       * stop once more than MAX_GET_DATAS resends have been made;
       * otherwise, resend the request once more than GET_DATA_WAIT_CNTR_MAX
         ticks pass without a reply;
       * otherwise, resend a response that has not yet been sent;
       * otherwise, on a bridge whose last response had 'more' set, clear
         saw_data_ack so the request is sent again on later ticks.
    2. Ping response pending (responding_addrs non-empty):
       * a bridge sends it to the requester with 'tell_ping';
       * another node, after INITIAL_DELAY ticks, sends it toward
         ping_resp_addr with a route_ping. It raises ttl after more than
         ESCALATION_THRESH unacknowledged attempts, and abandons the ping
         (clear_ping()) once the count exceeds MAX_ROUTE_PINGS.
    3. A bridge that has not yet forwarded another node's addresses
       re-multicasts the ping, stopping once bridge_no_response_cntr exceeds
       MAX_PINGS.
    """
    global waiting_for_ack, unackd_ping_ctr, responding_addrs, responding_addrs_index, wait_cntr
    global noise_cntr, initial_delay_cntr, ttl
    global saw_data_ack, sent_get_data_resp, get_data_wait_cntr, get_data_sent_cnt, is_get_data
    global bridge_no_response_cntr

    if is_get_data:
        if get_data_sent_cnt > MAX_GET_DATAS:
            is_get_data = False
        elif not saw_data_ack:
            # Still waiting for a reply: resend the request after a few ticks
            get_data_wait_cntr += 1
            if get_data_wait_cntr > GET_DATA_WAIT_CNTR_MAX:
                sent = False
                if get_data_arg_count == 0:
                    sent = mcastRpc(MCAST_GROUP, 1, 'get_data', last_seq, get_data_addr[0:3], data_func_name)
                elif get_data_arg_count == 1:
                    sent = mcastRpc(MCAST_GROUP, 1, 'get_data1', last_seq, get_data_addr[0:3], data_func_name, get_data_arg1)
                if sent:
                    get_data_wait_cntr = 0
                    get_data_sent_cnt += 1
        elif not sent_get_data_resp:
            # Have a reply but could not send it yet: try again
            sent_get_data_resp = mcastRpc(MCAST_GROUP, 1, 'get_data_ack', last_seq, get_data_addr[0:3], get_data_resp, get_data_more)
        elif is_bridge and get_data_more:
            #check_get_data_more()
            saw_data_ack = False
    elif responding_addrs:
        if is_bridge:
            if rpc(ping_resp_addr, 'tell_ping', last_seq, responding_addrs, True if ping_data_func else False):
                responding_addrs = ''
                #writePin(0, True)
                #global unable_to_tell
                #unable_to_tell += 1
            else:
                unackd_ping_ctr += 1
                #writePin(0, False)
        elif initial_delay_cntr > INITIAL_DELAY:
            #mcastRpc(MCAST_GROUP, 1, 'printRtnVal', "initial_delay_cntr > INITIAL_DELAY")
            # Freeze the portion to send unless a previous route_ping is still unacknowledged
            if not waiting_for_ack:
                responding_addrs_index = len(responding_addrs)

            queued = False
            #mcastRpc(MCAST_GROUP, 1, 'printRtnVal', "getStat(9) < MAX_PKTS")
            # Too many unacknowledged attempts: reach further with a larger ttl
            if unackd_ping_ctr > ESCALATION_THRESH and ttl < MAX_TLL:
                ttl += 1
                unackd_ping_ctr = 0
            if ttl > 1:
                wait_cntr += 1
            else:
                # At ttl 1, attempt on every tick
                wait_cntr = WAIT_CNTR_MAX+1

            if wait_cntr > WAIT_CNTR_MAX:
                # radio_recv_buffs: presumably packets received since the last read -- confirm
                radio_recv_buffs = getStat(9) # auto-clears
                if radio_recv_buffs < MAX_PKTS:
                    wait_cntr = 0
                    queued = mcastRpc(MCAST_GROUP, ttl, 'route_ping', last_seq, ping_resp_addr, responding_addrs[:responding_addrs_index], my_rp_seq, ttl, ping_data_func)
            else:
                noise_cntr += 1
                #mcastRpc(MCAST_GROUP, 1, 'printRtnVal', " noise_cntr += 1")

            if queued:
                waiting_for_ack = True
                unackd_ping_ctr += 1
                if unackd_ping_ctr > MAX_ROUTE_PINGS:
                    clear_ping()
                    #mcastRpc(MCAST_GROUP, 5, 'printRtnVal', "MAX_ROUTE_PINGS")
                #mcastRpc(MCAST_GROUP, 1, 'printRtnVal', "queued")
        else:
            initial_delay_cntr += 1
            #mcastRpc(MCAST_GROUP, 1, 'printRtnVal', initial_delay_cntr)
    elif is_bridge and not sent_addresses:
        # Bridge has only reported itself so far: ping again, a limited number of times
        if bridge_no_response_cntr > MAX_PINGS:
            pass # We don't clear the ping here in case we end up hearing from someone
        elif mcastRpc(MCAST_GROUP, 1, 'ping', last_seq, ping_data_func):
            bridge_no_response_cntr += 1
            initial_delay_cntr = INITIAL_DELAY
