# (c) Copyright 2010-2011, Synapse Wireless, Inc.


import logging
import timeit
import binascii
import sys
import pprint

from snapconnect import snap

import snapconnect_helpers


log = logging.getLogger()


class PollingTester(object):
    """Run a chain of SNAP polling benchmarks and log how long each one takes.

    The constructor opens the serial bridge and starts the first enabled
    stage. Each stage's completion callback starts the next enabled stage,
    always in this order:

    1. ``ping`` -- ping the network to find nodes.
    2. ``get_data_addr`` -- call ``get_data_func`` on a single node.
    3. ``get_data_all`` -- call ``get_data_func`` on every node.
    4. ``discover_function`` -- ping again using that discovery function.

    Stages whose flag is falsy are skipped. When no enabled stage remains,
    ``sys.exit()`` is called to end the run.
    """

    # Positions in the post-ping stage list used by _start_stages_from().
    _GET_DATA, _GET_DATA_ALL, _DISCOVERY = range(3)

    def __init__(self, ping=True,
                 get_data_addr=None,
                 get_data_func='random',
                 get_data_all=False,
                 serial_type=snap.SERIAL_TYPE_SNAPSTICK100,
                 serial_port=1,
                 wait_time=5.0,
                 seq_num=0,
                 discover_function=""):
        """Configure the SNAP instance and kick off the first enabled stage.

        :param ping: run the initial node-discovery ping.
        :param get_data_addr: address of one node to poll, given either as a
            3-item raw address or as hex text (``.`` separators and ``\\x``
            prefixes are stripped). ``None`` skips this stage.
        :param get_data_func: function name passed to the get-data requests.
        :param get_data_all: poll every node known to the polling framework.
        :param serial_type: serial bridge type passed to ``open_serial``.
        :param serial_port: serial port passed to ``open_serial``.
        :param wait_time: stored as the framework's ``ping_wait_time``.
        :param seq_num: starting sequence number for the framework.
        :param discover_function: if non-empty, function name used for a
            final discovery ping.

        Calls ``sys.exit()`` if ``get_data_addr`` is not a valid 3-byte
        address, or if no stage is enabled.
        """
        # NOTE(review): PollingFrameworkHelper is not defined or imported by
        # name anywhere in this file; it is looked up on the snapconnect_helpers
        # module imported at the top -- confirm that is where it lives.
        self.polling_framework = snapconnect_helpers.PollingFrameworkHelper()
        self.polling_framework.ping_wait_time = wait_time
        # Stored one below the requested start value; presumably the framework
        # increments before each use -- verify against PollingFrameworkHelper.
        self.polling_framework.current_sequence_number = seq_num - 1

        # Handlers exposed to the SNAP instance by name.
        self.snap_instance = snap.Snap(funcs={
            'get_data_result': self.polling_framework.get_data_result,
            'tell_ping': self.polling_framework.tell_ping,
            'get_data': self.polling_framework.on_get_data,
            'rpA': self.polling_framework.on_rpA,
        })
        self.polling_framework.snap_instance = self.snap_instance
        #self.snap_instance.set_hook(snap.hooks.HOOK_RPC_SENT, self.polling_framework.on_rpc_sent)
        self.snap_instance.save_nv_param(snap.NV_GROUP_INTEREST_MASK_ID, 0x101)

        self.snap_instance.open_serial(serial_type, serial_port)

        self.snap_instance.save_nv_param(snap.NV_MESH_ROUTE_AGE_MAX_TIMEOUT_ID, 0)
        self.snap_instance.save_nv_param(snap.NV_LOCKDOWN_FLAGS_ID, 0x2)

        self.get_data_addr = self._parse_address(get_data_addr)
        self.get_data_func = get_data_func
        self.get_data_all = get_data_all
        self.discover_function = discover_function

        if ping:
            log.critical("Sending ping request to find nodes")
            self.polling_framework.start_ping(None, self.on_ping_finished)
        else:
            self._start_stages_from(self._GET_DATA)

    @staticmethod
    def _parse_address(raw):
        """Return *raw* as a 3-byte SNAP address, or ``None`` if *raw* is None.

        A 3-item value is returned unchanged; anything else is treated as hex
        text. Logs and calls ``sys.exit()`` if the value is not a valid address.
        """
        if raw is None:
            return None
        try:
            if len(raw) == 3:
                # Already a raw 3-byte address.
                return raw
            address = binascii.unhexlify(raw.replace('.', '').replace('\\x', ''))
            if len(address) != 3:
                raise ValueError("expected exactly 3 address bytes")
        # Python 2 unhexlify raises TypeError on bad hex; Python 3 raises
        # binascii.Error (a ValueError subclass). len() of a non-sequence
        # also raises TypeError.
        except (TypeError, ValueError, binascii.Error):
            log.critical("You entered an invalid network address")
            sys.exit()
        return address

    def _start_get_data(self):
        """Request ``get_data_func`` from the single configured node."""
        log.critical("Sending get_data request to %s", binascii.hexlify(self.get_data_addr))
        self.polling_framework.get_data(self.get_data_addr, self.get_data_func, self.on_get_data_finished)

    def _start_get_data_all(self):
        """Request ``get_data_func`` from every node the framework knows."""
        self.polling_framework.get_data_all_nodes(self.get_data_func, self.on_get_data_all_nodes_finished)

    def _start_discovery(self):
        """Ping again using ``discover_function``."""
        self.polling_framework.start_ping(self.discover_function, self.on_ping_discovery_finished)

    def _start_stages_from(self, first):
        """Start the first enabled stage at or after position *first*, else exit.

        :param first: one of ``_GET_DATA``, ``_GET_DATA_ALL``, ``_DISCOVERY``.
        """
        stages = (
            (self.get_data_addr, self._start_get_data),
            (self.get_data_all, self._start_get_data_all),
            (self.discover_function, self._start_discovery),
        )
        for enabled, start in stages[first:]:
            if enabled:
                start()
                return
        sys.exit()

    def _ping_elapsed(self):
        """Seconds since the last ping started, minus the framework's ping wait time."""
        return (timeit.default_timer()
                - self.polling_framework.start_time
                - self.polling_framework.ping_wait_time)

    def on_ping_finished(self):
        """Log the node-discovery ping result, then continue the chain."""
        log.critical("Ping finished, saw %i nodes in %.3f seconds",
                     len(self.polling_framework.node_addresses), self._ping_elapsed())
        self._start_stages_from(self._GET_DATA)

    def on_get_data_finished(self, address, response):
        """Log the single-node response, then continue the chain."""
        log.critical("Received response in %.3f seconds: %s ",
                     timeit.default_timer() - self.polling_framework.start_time, str(response))
        self._start_stages_from(self._GET_DATA_ALL)

    def on_get_data_all_nodes_finished(self, results):
        """Log the all-nodes result count, then continue the chain."""
        log.critical("Get data from all nodes finished, saw %i nodes in %.3f seconds",
                     len(results),
                     timeit.default_timer() - self.polling_framework.get_data_all_start_time)
        self._start_stages_from(self._DISCOVERY)

    def on_ping_discovery_finished(self):
        """Log the discovery ping result and end the run (last stage)."""
        log.critical("Ping with discovery finished, saw %i nodes in %.3f seconds",
                     len(self.polling_framework.node_addresses), self._ping_elapsed())
        sys.exit()


if __name__ == '__main__':
    # Command-line entry point: parse options, configure logging, then run a
    # PollingTester until one of its callbacks calls sys.exit().
    import optparse

    # NOTE(review): PollingFrameworkHelper is not defined or imported by name
    # in this file; it is looked up on the snapconnect_helpers module imported
    # at the top -- confirm that is where it lives.
    print("Benchmark app version: %s" % (snapconnect_helpers.PollingFrameworkHelper.VERSION,))
    print("SNAP Connect version: %s" % (snap.VERSION,))

    # On Linux and macOS the serial port is a device path; on other
    # platforms it is an integer port number.
    if sys.platform.startswith("linux") or sys.platform == 'darwin':
        serport_type = "string"
        serport_default = "/dev/ttyS1"
    else:
        serport_type = "int"
        serport_default = 0

    parser = optparse.OptionParser()
    parser.add_option("-l", "--logging-level", type="int", dest="logging_level", default=0, help="Logging level (0-2)")
    parser.add_option("-d", "--discover", action="store_false", default=True, dest="ping", help="Disable discovering all nodes")
    parser.add_option("-g", "--get-data", type="string", dest="get_data_addr", default=None, help="Get data from specified address")
    parser.add_option("-f", "--function", type="string", dest="get_data_func", default='random', help="Function name to call on specified address")
    parser.add_option("-a", "--get-data-all", action="store_true", default=False, dest="get_data_all", help="Get data from all nodes found")
    parser.add_option("-n", "--serial_port_number", type=serport_type, dest="serial_port_num", metavar="serial_port_num", default=serport_default, help="Open a connection to the specified port number")
    parser.add_option("-t", "--serial_port_type", type="int", dest="port_type", metavar="port_type", default=snap.SERIAL_TYPE_SNAPSTICK100, help="Open a connection to the specified port type")
    parser.add_option("-w", "--wait-time", type="float", dest="wait_time", default=5.0, help="Time to wait for replies")
    parser.add_option("-s", "--sequence-number", type="int", dest="sequence_number", default=0, help="Starting sequence number")
    parser.add_option("-r", "--discover-function", type="string", dest="discover_function", default="", help="Discovery function name")

    options, _args = parser.parse_args()

    # 0 -> CRITICAL, 1 -> INFO, any other value -> DEBUG.
    logging_level = {0: logging.CRITICAL, 1: logging.INFO}.get(options.logging_level, logging.DEBUG)

    logging.basicConfig(level=logging_level, format='%(asctime)s:%(msecs)03d %(levelname)-8s %(name)-8s %(message)s', datefmt='%H:%M:%S')

    tester = None
    try:
        tester = PollingTester(ping=options.ping,
                               get_data_addr=options.get_data_addr,
                               get_data_func=options.get_data_func,
                               get_data_all=options.get_data_all,
                               serial_type=options.port_type,
                               serial_port=options.serial_port_num,
                               wait_time=options.wait_time,
                               seq_num=options.sequence_number,
                               discover_function=options.discover_function)
        tester.snap_instance.loop()
    except SystemExit:
        # The tester ends the benchmark by calling sys.exit() from its
        # callbacks (or its constructor); treat that as a normal, quiet exit.
        pass
