import socket
import senti_pb2 as senti
from google.protobuf.internal.decoder import _DecodeVarint32

# Lookup table from SentiBoard message ID to a protobuf message object of the
# matching type. Exactly one object is created per ID when the module loads,
# so every lookup for a given ID returns that same object.
message_lookup = {
    msg_id: msg_type()
    for msg_id, msg_type in (
        (senti.RAW_MSG, senti.RawMsg),
        (senti.TRIGGER_MSG, senti.TriggerMsg),
        (senti.INPUTCAPTURE_MSG, senti.InputCaptureMsg),
        (senti.GNSS_POS_LLH, senti.GNSSPosLLH),
        (senti.GNSS_POS_ECEF, senti.GNSSPosECEF),
        (senti.GNSS_RELPOS_NED, senti.GNSSRelPosNED),
        (senti.GNSS_RFI_STATUS, senti.GNSSRfiStatus),
        (senti.GNSS_STATUS, senti.GNSSStatus),
        (senti.GNSS_TIME, senti.GNSSTime),
        (senti.GNSS_VEL_ECEF, senti.GNSSVelECEF),
        (senti.GNSS_VEL_NED, senti.GNSSVelNED),
        (senti.IMU_MSG, senti.IMUMsg),
        (senti.IMU_MAG_MSG, senti.IMUMagMsg),
        (senti.IMU_MAG_ORIENTATION_MSG, senti.IMUMagOrientationMsg),
        (senti.INS_ALL_STATES, senti.INSAllStates),
        (senti.INS_IMUBIAS_STATES, senti.INSImuBiasStates),
    )
}

class SentiMessageParser:
    """Receive SentiBoard datagrams over UDP and dispatch their messages.

    Each datagram is parsed as a sequence of frames. A frame consists of a
    sync marker (varint ``SYNC_ID``), a message ID (varint), a payload length
    (varint) and ``length`` payload bytes. Payloads whose ID has a registered
    callback are decoded into the matching protobuf object from
    ``message_lookup``; all other payloads go to ``default_message_handler``.

    Typical use: ``connect()``, ``register_callback(...)``, then call
    ``receive_message()`` repeatedly, and ``disconnect()`` when done.
    """

    # SentiBoard max buffer size in bytes
    MAX_DATAGRAM_SIZE = 8192
    # Varint value that marks the start of every frame inside a datagram.
    SYNC_ID = 0x42

    def __init__(self, host, port):
        """Store the local address to bind to; no socket is opened yet.

        Args:
            host: Local address the UDP socket will be bound to.
            port: Local UDP port the socket will be bound to.
        """
        self.callbacks = {}
        self.host = host
        self.port = port
        self.socket = None

    def connect(self):
        """Open a UDP socket bound to ``(host, port)``.

        A socket left over from an earlier ``connect()`` is closed first.
        If binding fails, the new socket is closed, ``self.socket`` stays
        ``None`` and the ``OSError`` propagates.
        """
        self.disconnect()
        sock = socket.socket(socket.AF_INET,  # Internet
                             socket.SOCK_DGRAM)  # UDP
        try:
            sock.bind((self.host, self.port))
        except OSError:
            # Do not leak the descriptor when the address is unavailable.
            sock.close()
            raise
        self.socket = sock

    def disconnect(self):
        """Close the socket if one is open; safe to call repeatedly."""
        if self.socket is not None:
            self.socket.close()
            # Drop the reference so a closed socket is never reused.
            self.socket = None

    def register_callback(self, message_id, callback):
        """Register ``callback`` for ``message_id``, replacing any previous one.

        Args:
            message_id: Message ID as carried in the frame header.
            callback: Callable invoked with the decoded protobuf message.
        """
        self.callbacks[message_id] = callback

    def receive_message(self):
        """Block for one datagram and dispatch every frame it contains.

        ``connect()`` must have been called first. Bytes that precede a sync
        marker are skipped; if no further sync marker is found, parsing of
        the datagram ends quietly.

        Raises:
            ValueError: A sync marker was found but the header or payload
                that follows it runs past the end of the datagram. Frames
                located earlier in the same datagram have already been
                dispatched at that point.
        """
        data, _addr = self.socket.recvfrom(self.MAX_DATAGRAM_SIZE)
        size = len(data)
        pos = 0

        while pos < size:
            pos = self._skip_to_message(data, pos)
            if pos is None:
                # Only non-frame bytes remained in the datagram.
                break

            try:
                msg_id, pos = _DecodeVarint32(data, pos)
                # Find length
                msg_len, pos = _DecodeVarint32(data, pos)
            except IndexError as exc:
                raise ValueError(
                    "truncated SentiBoard message header") from exc

            end = pos + msg_len
            if end > size:
                # Explicit raise: an assert would be stripped under -O and
                # a short payload would then be parsed silently.
                raise ValueError(
                    "truncated SentiBoard message %d: header announces %d "
                    "payload bytes but only %d remain"
                    % (msg_id, msg_len, size - pos))

            # Handle message
            self.handle_message(msg_id, data[pos:end])
            pos = end

    def _skip_to_message(self, data, pos):
        """Return the offset just past the next sync marker, or None.

        Scans ``data`` varint by varint starting at ``pos`` until a value
        equal to ``SYNC_ID`` is decoded.
        """
        size = len(data)
        while pos < size:
            try:
                value, pos = _DecodeVarint32(data, pos)
            except IndexError:
                # The data ended in the middle of a varint: no sync marker.
                return None
            if value == self.SYNC_ID:
                return pos
        return None

    def handle_message(self, message_id, data):
        """Decode ``data`` and pass it to the callback for ``message_id``.

        The callback receives the shared object stored in ``message_lookup``
        for this ID. That object is re-parsed in place for every message of
        the same ID, so a callback must copy it to keep the contents.
        Payloads with no registered callback are forwarded undecoded to
        ``default_message_handler``.

        Raises:
            KeyError: A callback is registered for an ID that has no entry
                in ``message_lookup``.
        """
        if message_id in self.callbacks:
            callback = self.callbacks[message_id]
            pb_msg = message_lookup[message_id]
            pb_msg.ParseFromString(data)
            callback(pb_msg)
        else:
            self.default_message_handler(message_id, data)

    def default_message_handler(self, message_id, data):
        """Handle a payload whose message ID has no registered callback.

        This implementation ignores the payload. The data can be stored to
        disk or ignored if it should be omitted.
        """
        pass