#!/usr/bin/env python

# Copyright (C) 2010  Internet Systems Consortium, Inc. ("ISC")
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
#

"""
AFTR remote configuration server
"""
__version__ = "$Id: aftrconf.py 719 2010-07-02 19:34:37Z pselkirk $"

transport = "http"
#transport = "socket"
#transport = "debug"

import sys
import time
import socket
import getopt
import shutil
from lxml import etree

################ configuration dicts ################

class NatEntry(object):
    """static NAT entry

    Binds an internal source (saddr, sport, protocol) reached through the
    tunnel ``parent`` to the external port ``nport`` on that tunnel's
    natted address ``parent.addr``.  Address, port and protocol fields hold
    the text forms written to the config file (e.g. '10.0.0.1', '80',
    'tcp').
    """
    __slots__ = ['parent', 'protocol', 'saddr', 'sport', 'nport']
    # instances are mutable and compare by value, so they are unhashable
    __hash__ = None

    def __init__(self, parent, protocol, saddr, sport, nport):
        """initializer"""
        self.parent = parent
        self.protocol = protocol
        self.saddr = saddr
        self.sport = sport
        self.nport = nport

    def __eq__(self, other):
        """equality on (saddr, sport, protocol, nport)

        NOTE(review): the owning tunnel (``parent``) is not compared, so
        entries belonging to two different tunnels can compare equal.
        """
        if not isinstance(other, NatEntry):
            return False
        return (self.saddr == other.saddr and
                self.sport == other.sport and
                self.protocol == other.protocol and
                self.nport == other.nport)

    def __ne__(self, other):
        """inequality"""
        return not self.__eq__(other)

    @staticmethod
    def _portkey(port):
        """sort key for a port string: numeric first, text as tie-break"""
        try:
            return (int(port), port)
        except (TypeError, ValueError):
            # not a number (should not happen for canonical ports):
            # sort it before every valid port
            return (-1, port)

    def __cmp__(self, other):
        """three-way compare

        Orders by source IPv4 address, then source port, protocol and
        natted port.  Ports are compared numerically, so '80' sorts before
        '1024'.  Any non-NatEntry compares less than a NatEntry.
        """
        if self == other:
            return 0
        if not isinstance(other, NatEntry):
            return 1
        pairs = (
            (socket.inet_pton(socket.AF_INET, self.saddr),
             socket.inet_pton(socket.AF_INET, other.saddr)),
            (self._portkey(self.sport), self._portkey(other.sport)),
            (self.protocol, other.protocol),
            (self._portkey(self.nport), self._portkey(other.nport)),
        )
        for mine, theirs in pairs:
            if mine != theirs:
                return (mine > theirs) - (mine < theirs)
        # unreachable: identical keys imply self == other above
        raise ValueError('NatEntry.__cmp__')

    def __lt__(self, other):
        """rich-comparison form of __cmp__, used by sort()/sorted()"""
        return self.__cmp__(other) < 0

    def totext(self, file_):
        """export to text: write one 'nat' config line to file_"""
        file_.write('nat %s %s %s %s %s %s\n' %
                    (self.parent.remote, self.protocol, self.saddr,
                     self.sport, self.parent.addr, self.nport))

    def toxml(self, parent):
        """export to xml: append a <natEntry> element below parent"""
        nat = etree.SubElement(parent, 'natEntry')
        etree.SubElement(nat, 'tunnel').text = self.parent.remote
        etree.SubElement(nat, 'protocol').text = self.protocol
        etree.SubElement(nat, 'sourceAddress').text = self.saddr
        etree.SubElement(nat, 'sourcePort').text = self.sport
        etree.SubElement(nat, 'nattedAddress').text = self.parent.addr
        etree.SubElement(nat, 'nattedPort').text = self.nport

    def hashkey_src(self):
        """CONFNAT key by source: (tunnel remote, saddr, sport, protocol)"""
        return (self.parent.remote, self.saddr, self.sport, self.protocol)

    def hashkey_nat(self):
        """CONFNAT key by natted side: (natted addr, nport, protocol)"""
        return (self.parent.addr, self.nport, self.protocol)

class Tunnel(object):
    """tunnel entry

    A tunnel identified by its remote IPv6 address and bound to one natted
    IPv4 address (``addr``).  ``entries`` holds the tunnel's NatEntry
    objects; ``policies`` accumulates the raw mss/mtu/toobig config lines
    attached to it, each already terminated by a newline.
    """
    __slots__ = ['remote', 'addr', 'policies', 'entries']
    # instances are mutable and compare by value, so they are unhashable
    __hash__ = None

    def __init__(self, remote, addr):
        """initializer"""
        self.remote = remote
        self.addr = addr
        self.policies = ""
        self.entries = []

    def __eq__(self, other):
        """shallow equality: same remote and natted address"""
        if not isinstance(other, Tunnel):
            return False
        return self.remote == other.remote and self.addr == other.addr

    def __ne__(self, other):
        """shallow inequality"""
        return not self.__eq__(other)

    def __cmp__(self, other):
        """three-way compare

        Orders by remote IPv6 address, then by natted IPv4 address (both
        in binary form).  Any non-Tunnel compares less than a Tunnel.
        """
        if self == other:
            return 0
        if not isinstance(other, Tunnel):
            return 1
        for family, mine, theirs in ((socket.AF_INET6, self.remote, other.remote),
                                     (socket.AF_INET, self.addr, other.addr)):
            key1 = socket.inet_pton(family, mine)
            key2 = socket.inet_pton(family, theirs)
            if key1 != key2:
                return (key1 > key2) - (key1 < key2)
        # unreachable: identical keys imply self == other above
        raise ValueError('Tunnel.__cmp__')

    def __lt__(self, other):
        """rich-comparison form of __cmp__, used by sort()/sorted()"""
        return self.__cmp__(other) < 0

    def totext(self, file_):
        """export to text

        A tunnel with NAT entries is written as its 'nat' lines; a tunnel
        without entries is written as a single 'tunnel <remote> <addr>'
        line.  Any policy lines follow.
        """
        if self.entries:
            for entry in self.entries:
                entry.totext(file_)
        else:
            file_.write('tunnel %s %s\n' % (self.remote, self.addr))
        if self.policies:
            # policies already end in '\n'; the extra newline leaves a
            # blank separator line after them
            file_.write(self.policies + '\n')

    def toxml(self, parent):
        """export to xml: one <natEntry> per NAT entry below parent"""
        for entry in self.entries:
            entry.toxml(parent)

class PoolAddr(object):
    """ information about a managed IPv4 address

    Holds the tcp and udp port ranges (min/max) configured for one pool
    address.  Values are stored exactly as given (ints or strings).
    """
    __slots__ = ['addr', 'tcpmin', 'tcpmax', 'udpmin', 'udpmax']
    # each pool could have a dict of assigned tunnels,
    # if we ever wanted to look things up that way

    def __init__(self, addr, tcpmin, tcpmax, udpmin, udpmax):
        """initializer"""
        self.addr = addr
        self.tcpmin, self.tcpmax = tcpmin, tcpmax
        self.udpmin, self.udpmax = udpmin, udpmax

    def copy(self):
        """return a new PoolAddr carrying the same field values"""
        values = [getattr(self, slot) for slot in PoolAddr.__slots__]
        return PoolAddr(*values)

class Conf(dict):
    """conf state as a dict extension

    Maps a tunnel's remote IPv6 address to its Tunnel object.
    """

    def totext(self, file_):
        """export every tunnel, in sorted order, as config text"""
        ordered = list(self.values())
        ordered.sort()
        for tunnel in ordered:
            tunnel.totext(file_)

    def toxml(self, parent):
        """export every tunnel, in sorted order, below XML element parent"""
        ordered = list(self.values())
        ordered.sort()
        for tunnel in ordered:
            tunnel.toxml(parent)

    def gettunnel(self, ipv6, addr):
        """get or create a tunnel entry

        Returns the tunnel stored under *ipv6*.  When there is none, a new
        Tunnel(ipv6, addr) is created and stored if *addr* is given,
        otherwise None is returned.  Raises ValueError when an existing
        tunnel is bound to a natted address other than *addr*.
        """
        found = self.get(ipv6)
        if found is None:
            if addr is None:
                return None
            found = Tunnel(ipv6, addr)
            self[ipv6] = found
            return found
        if addr is not None and found.addr != addr:
            raise ValueError('tunnel natted mismatch')
        return found

# static NAT entries, each stored under two keys: NatEntry.hashkey_src()
# (tunnel remote, saddr, sport, protocol) and NatEntry.hashkey_nat()
# (natted addr, nport, protocol)
CONFNAT = {}
# Tunnel objects keyed by remote IPv6 address
CONFTUN = Conf()
# managed IPv4 pool addresses; the 'default' entry's port ranges
# (tcp 2048-65535, udp 512-65535) are copied into newly declared pools
CONFPOOL = {'default': PoolAddr('default', 2048, 65535, 512, 65535)}

################ config file routines ################

# the ConfigFile instance used by RpcParser (rewrite/truncate/append);
# None until assigned -- presumably at startup, outside this view
CONFFILE = None

def canonv4(text):
    """canonicalize an IPv4 address

    Round-trips *text* through inet_pton/inet_ntop and returns the
    normalized dotted-quad string, or None if *text* is None or is not a
    valid IPv4 address.
    """
    if text is None:
        return None
    family = socket.AF_INET
    try:
        return socket.inet_ntop(family, socket.inet_pton(family, text))
    except socket.error:
        return None

def canonv6(text):
    """canonicalize an IPv6 address

    Round-trips *text* through inet_pton/inet_ntop and returns the
    normalized (compressed, lower-case) form, or None if *text* is None or
    is not a valid IPv6 address.
    """
    if text is None:
        return None
    family = socket.AF_INET6
    try:
        return socket.inet_ntop(family, socket.inet_pton(family, text))
    except socket.error:
        return None

def canonport(text):
    """canonicalize a port number

    Returns the decimal string form of *text* (e.g. '0080' -> '80') when
    it parses as an integer in 1..65535, otherwise None.  A None input
    (e.g. the text of an empty XML element) yields None instead of
    raising, matching canonv4/canonv6.
    """
    if text is None:
        return None
    try:
        port = int(text)
    except (TypeError, ValueError):
        return None
    if not 0 < port <= 65535:
        return None
    return str(port)

class ConfigFile:
    """parse aftr.conf

    Everything from byte offset ``fd_rewrite`` to the end of the file is
    regenerated by rewrite() from CONFTUN; text before that offset is left
    alone.  ``fd_rewrite`` is set by the first rewrite-marker line or the
    first 'nat'/'tunnel' line, whichever comes first (0 means neither has
    been seen yet).
    """

    fd_marker = '#### Everything below this line is ' + \
        'subject to rewriting by aftrconf.py ####'

    def __init__(self, name):
        """initializer - parse config file

        Opens *name* for reading and writing, saves a backup copy as
        ``name + '~'``, parses every line into CONFPOOL / CONFTUN / CONFNAT
        and then calls rewrite().  Exits the process with status 1 if the
        file cannot be opened or backed up.
        """
        try:
            self.fd = open(name, "r+")
        except IOError as err:
            # most likely ENOENT - No such file or directory
            print(err.strerror or err)
            sys.exit(1)
        try:
            shutil.copy2(name, name + '~')
        except EnvironmentError as err:
            # most likely EACCES - Permission denied; copy2 can also raise
            # OSError (while copying metadata) or shutil.Error
            print(err.strerror or err)
            sys.exit(1)
        self.fd_rewrite = 0
        # prr/nonat/debug lines seen after fd_rewrite; rewrite() moves them
        # in front of the regenerated section so they are not lost
        self.fd_section2 = ""
        # whether the last line read was newline-terminated
        self.fd_eol = True
        while True:
            fd_pos = self.fd.tell()
            line = self.fd.readline()
            if not line:
                break
            self.fd_eol = line.endswith('\n')
            self.cf_parse_line(line.rstrip(), fd_pos)
        # leave file open for rewriting
        self.rewrite()

    def cf_parse_line(self, text, fd_pos):
        """parse one (right-stripped) line found at byte offset fd_pos"""
        if text == self.fd_marker:
            if self.fd_rewrite == 0:
                self.fd_rewrite = fd_pos
            return
        args = text.split()
        if not args:
            return
        cmd = args[0]
        if cmd == 'pool':
            self.cf_parse_pool(args)
        elif cmd == 'default' and len(args) > 1 and args[1] == 'pool':
            self.cf_parse_defpool(args)
        elif cmd == 'nat':
            if self.fd_rewrite == 0:
                self.fd_rewrite = fd_pos
            self.cf_parse_nat(args)
        elif cmd == 'tunnel':
            if self.fd_rewrite == 0:
                self.fd_rewrite = fd_pos
            self.cf_parse_tunnel(args)
        elif cmd in ('mss', 'mtu', 'toobig'):
            self.cf_parse_policies(args)
        elif cmd in ('prr', 'nonat', 'debug'):
            if self.fd_rewrite != 0:
                self.fd_section2 += text + '\n'
        # NOTE(review): any other command found after fd_rewrite (e.g. a
        # 'pool' line placed after the first 'nat' line) is not preserved
        # and is dropped from the file by rewrite() -- confirm intended.

    def _cf_parse_range(self, pool, protocol, text):
        """store a '<min>-<max>' port range into pool for tcp or udp"""
        min_, dash, max_ = text.partition('-')
        if not dash:
            raise SyntaxError
        try:
            # the bounds are converted with int() when checking requests
            int(min_)
            int(max_)
        except ValueError:
            raise SyntaxError(text)
        if protocol == 'tcp':
            pool.tcpmin = min_
            pool.tcpmax = max_
        elif protocol == 'udp':
            pool.udpmin = min_
            pool.udpmax = max_

    def cf_parse_defpool(self, args):
        """ parse a 'default pool <tcp|udp> <min>-<max>' entry """
        if len(args) != 4:
            raise SyntaxError
        self._cf_parse_range(CONFPOOL['default'], args[2], args[3])

    def cf_parse_pool(self, args):
        """ parse a 'pool <addr> [<tcp|udp> <min>-<max>]' entry

        A new pool starts as a copy of the default pool's ranges.
        """
        if len(args) not in (2, 4):
            raise SyntaxError
        addr = args[1]
        pool = CONFPOOL.get(addr)
        if pool is None:
            pool = CONFPOOL['default'].copy()
            pool.addr = addr
            CONFPOOL[addr] = pool
        if len(args) == 4:
            self._cf_parse_range(pool, args[2], args[3])

    def cf_parse_nat(self, args):
        """parse a 'nat <ipv6> <tcp|udp> <saddr> <sport> <naddr> <nport>' entry"""
        if len(args) != 7:
            raise SyntaxError
        remote = canonv6(args[1])
        if remote is None:
            raise SyntaxError(args[1])
        protocol = args[2]
        if protocol not in ('tcp', 'udp'):
            raise SyntaxError
        saddr = canonv4(args[3])
        if saddr is None:
            raise SyntaxError
        sport = canonport(args[4])
        if sport is None:
            raise SyntaxError
        naddr = canonv4(args[5])
        if naddr is None:
            raise SyntaxError
        nport = canonport(args[6])
        if nport is None:
            raise SyntaxError
        parent = CONFTUN.gettunnel(remote, naddr)
        nat_entry = NatEntry(parent, protocol, saddr, sport, nport)
        parent.entries.append(nat_entry)
        CONFNAT[nat_entry.hashkey_src()] = nat_entry
        CONFNAT[nat_entry.hashkey_nat()] = nat_entry

    def cf_parse_tunnel(self, args):
        """parse a 'tunnel <ipv6> [<naddr>]' entry"""
        if len(args) == 2:
            # no natted address: nothing is recorded for this tunnel
            return
        if len(args) < 2:
            raise SyntaxError
        remote = canonv6(args[1])
        if remote is None:
            raise SyntaxError
        naddr = canonv4(args[2])
        if naddr is None:
            raise SyntaxError
        CONFTUN.gettunnel(remote, naddr)

    def cf_parse_policies(self, args):
        """parse a mss, mtu, or toobig entry for an already-known tunnel"""
        if len(args) < 2:
            raise SyntaxError(" ".join(args))
        remote = canonv6(args[1])
        if remote is None:
            raise SyntaxError(" ".join(args))
        tunnel = CONFTUN.gettunnel(remote, None)
        # might return None if tunnel has not been declared yet,
        # or has been declared without a natted IPv4 address
        if tunnel is not None:
            tunnel.policies += " ".join(args) + '\n'

    def append(self, line):
        """write *line* at the current file position and flush"""
        self.fd.write(line)
        self.fd.flush()

    def truncate(self):
        """cut the file at fd_rewrite, leaving only the marker line"""
        if self.fd_rewrite > 0:
            self.fd.seek(self.fd_rewrite)
            self.fd.truncate()
            self.fd.write(self.fd_marker + '\n')
            self.fd.flush()

    def rewrite(self):
        """rewrite config file

        Truncates the file at fd_rewrite and writes the marker line followed
        by every tunnel in CONFTUN.  If relocated prr/nonat/debug lines are
        pending, they are written first, just in front of the marker.
        """
        if self.fd_rewrite == 0:
            # if there are no nat or tunnel commands, set the rewrite pointer
            # to the end of the file
            self.fd_rewrite = self.fd.tell()
            if not self.fd_eol:
                # the last line has no newline: terminate it so the marker
                # is not glued onto it
                self.fd.seek(self.fd_rewrite)
                self.fd.write('\n')
                self.fd_rewrite = self.fd.tell()
                self.fd_eol = True

        elif self.fd_section2:
            # consolidate the non-nat section 2 commands
            self.fd.seek(self.fd_rewrite)
            self.fd.truncate()
            self.fd.write('#### The following commands were ' +
                          'relocated here by aftrconf.py ####\n')
            self.fd.write(self.fd_section2)
            self.fd.write('\n')
            self.fd_rewrite = self.fd.tell()
            self.fd_section2 = ""

        self.fd.seek(self.fd_rewrite)
        self.fd.truncate()
        self.fd.write(self.fd_marker + '\n')
        CONFTUN.totext(self.fd)
        self.fd.flush()

################ rpc processing ################

# namespace map given to the <rpc-reply> element built by RpcParser.parse
NSMAP = {'aftr': 'http://aftr.isc.org/mapping/1.0'}

class RpcBinding(object):
    """ NAT binding information in an rpc request

    Collects the <tunnel>, <protocol>, <sourceAddress>, <sourcePort>,
    <nattedAddress> and <nattedPort> children of a request element.
    Addresses and ports are canonicalized (None when absent or invalid);
    the protocol text is taken verbatim.
    """
    __slots__ = ['tunnel', 'protocol', 'saddr', 'sport', 'naddr', 'nport']

    def __init__(self, element):
        """ initializer: read the binding fields from element's children """
        for field in self.__slots__:
            setattr(self, field, None)

        # child tag -> (attribute name, converter applied to the child text)
        readers = {
            'tunnel': ('tunnel', canonv6),
            'protocol': ('protocol', lambda text: text),
            'sourceAddress': ('saddr', canonv4),
            'sourcePort': ('sport', canonport),
            'nattedAddress': ('naddr', canonv4),
            'nattedPort': ('nport', canonport),
        }
        for child in element:
            reader = readers.get(child.tag)
            # XXX unrecognized subelements are silently ignored
            if reader is not None:
                field, convert = reader
                setattr(self, field, convert(child.text))

    def incomplete(self):
        """True if any of the six fields is missing"""
        return any(getattr(self, field) is None for field in self.__slots__)

    def incomplete_src(self):
        """True if a field needed for a lookup by source is missing"""
        return any(getattr(self, field) is None
                   for field in ('tunnel', 'protocol', 'saddr', 'sport'))

    def incomplete_nat(self):
        """True if a field needed for a lookup by natted address is missing"""
        return any(getattr(self, field) is None
                   for field in ('protocol', 'naddr', 'nport'))

    def tunnelonly(self):
        """True if the tunnel is the only field present"""
        if self.tunnel is None:
            return False
        return all(getattr(self, field) is None
                   for field in self.__slots__[1:])

    def lookup_by_src(self):
        """ find a matching NatEntry by internal address """
        key = (self.tunnel, self.saddr, self.sport, self.protocol)
        return CONFNAT.get(key)

    def lookup_by_nat(self):
        """ find a matching NatEntry by external address """
        key = (self.naddr, self.nport, self.protocol)
        return CONFNAT.get(key)

class RpcParser:
    """routines to parse and process an XML <rpc> request

    parse() turns one request document into an <rpc-reply>.  Each child of
    <rpc> is an operation (<create>, <delete>, <get> or <flush>) handled by
    the matching rpc_parse_* method, which appends <ok>, <rpc-error> or
    <conf> elements to the reply and returns True when the configuration
    changed and aftr must be asked to reload.
    """

    def parse(self, request):
        """parse a message from the provisioning system

        Returns the serialized <rpc-reply>, or the plain string
        "error parsing request" when *request* is not well-formed XML.
        """
        parser = etree.XMLParser(remove_blank_text=True,
                                 remove_comments=True)
        try:
            root = etree.fromstring(request, parser)
        except etree.XMLSyntaxError:
            # XXX syslog?
            return "error parsing request"

        reply = etree.Element('rpc-reply', attrib=root.attrib, nsmap=NSMAP)

        # if the config file has been changed, ask aftr to reload it;
        # if that's successful, rewrite the file
        if self.rpc_parse(root, reply):
            resp = AFTRSOCK.askreload()
            if resp == 'OK':
                CONFFILE.rewrite()
            else:
                self.generror(reply, 'ERROR: ' + resp)

        return etree.tostring(reply, pretty_print=True)

    def rpc_parse(self, root, reply):
        """ parse <rpc> element

        Runs every operation in *root*, in document order, and returns True
        if any of them changed the configuration (reload needed).
        """
        if root.tag != 'rpc':
            return self.generror(reply, 'ERROR: invalid document: ' +
                                 'missing \'rpc\' element')
        reload_ = False
        for entry in root:
            if entry.tag == 'create':
                changed = self.rpc_parse_create(entry, reply)
            elif entry.tag == 'delete':
                changed = self.rpc_parse_delete(entry, reply)
            elif entry.tag == 'get':
                changed = self.rpc_parse_get(entry, reply)
            elif entry.tag == 'flush':
                changed = self.rpc_parse_flush(reply)
            else:
                changed = self.generror(reply,
                                        'ERROR: invalid operation %s' % entry.tag)
            # always run the operation first: a pending reload must not
            # short-circuit the remaining operations of the request
            reload_ = reload_ or changed
        return reload_

    def rpc_parse_create(self, entry, reply):
        """ parse <create> element

        Adds a static NAT binding.  All six fields (tunnel, protocol,
        sourceAddress, sourcePort, nattedAddress, nattedPort) are required.
        Returns True only when a new binding was added.
        """
        binding = RpcBinding(entry)

        # if any element is missing, ERROR
        # XXX <tunnel><nattedAddress> to create tunnel?
        if binding.incomplete():
            return self.generror(reply, 'ERROR: malformed create request')

        tunnel = CONFTUN.gettunnel(binding.tunnel, None)

        if tunnel is None:
            # tunnel not known here: see if it was auto-created by a dynamic
            # nat binding, and if so, what its natted address is
            nsrc = AFTRSOCK.asktunnel(binding.tunnel, binding.naddr)
            if nsrc != binding.naddr:
                if canonv4(nsrc) is not None:
                    return self.generror(reply, 'ERROR: tunnel ' +
                                         binding.tunnel +
                                         ' is already bound to nattedAddress ' +
                                         nsrc)
                return self.generror(reply, 'ERROR: ' + nsrc)
        elif tunnel.addr != binding.naddr:
            # if nattedAddress does not match tunnel, ERROR
            return self.generror(reply, 'ERROR: tunnel ' +
                                 binding.tunnel +
                                 ' is already bound to nattedAddress ' +
                                 tunnel.addr)

        pool = CONFPOOL.get(binding.naddr)
        # if nattedAddress is not in the managed pool, ERROR
        if pool is None:
            return self.generror(reply, 'ERROR: external address not managed')

        if binding.protocol == 'tcp':
            low, high = pool.tcpmin, pool.tcpmax
        elif binding.protocol == 'udp':
            low, high = pool.udpmin, pool.udpmax
        else:
            # if protocol not tcp or udp, ERROR
            return self.generror(reply, 'ERROR: malformed create request')
        # reject a nattedPort that lies INSIDE the pool's range for this
        # protocol.  NOTE(review): presumably that range is reserved for
        # dynamic bindings, so static ports must fall outside it -- confirm.
        if int(low) <= int(binding.nport) <= int(high):
            return self.generror(reply, 'ERROR: external port out-of-range')

        # binding_src: existing entry for (tunnel, saddr, sport, protocol)
        # binding_nat: existing entry for (naddr, nport, protocol)
        binding_src = binding.lookup_by_src()
        binding_nat = binding.lookup_by_nat()
        if binding_src is not None:
            if binding_src == binding_nat:
                # if full binding exists, return <ok>
                self.genok(reply)
                return False  # don't need to reload
            if (binding_nat is not None and
                    binding_nat.parent != binding_src.parent):
                # if port bound to a different tunnel, ERROR
                return self.generror(reply,
                                     'ERROR: port assigned to ' +
                                     'another subscriber')
            # src is bound to another port; replacing it would need
            # delnat() + addnat(), which does not work today
            return self.generror(reply,
                                 'ERROR: nat change is not yet supported')
        if binding_nat is not None:
            if binding_nat.parent.remote != binding.tunnel:
                # if port bound to a different tunnel, ERROR
                return self.generror(reply,
                                     'ERROR: port assigned to ' +
                                     'another subscriber')
            # port is bound to a different src; replacing it would need
            # delnat() + addnat()
            return self.generror(reply,
                                 'ERROR: nat change is not yet supported')
        self.addnat(binding)
        return self.genok(reply)

    def rpc_parse_delete(self, entry, reply):
        """ parse <delete> element

        The fields present select what is deleted:
          - all six: the entry found by source must be the one found by
            natted address;
          - tunnel/protocol/sourceAddress/sourcePort: lookup by source;
          - protocol/nattedAddress/nattedPort: lookup by natted address;
          - tunnel only: the whole tunnel (see _delete_tunnel).
        Returns True when a single binding was removed (reload needed).
        """
        binding = RpcBinding(entry)

        if not binding.incomplete():
            nat_entry = binding.lookup_by_src()
            if nat_entry != binding.lookup_by_nat():
                return self.generror(reply, 'ERROR: no mapping found')
        elif not binding.incomplete_src():
            nat_entry = binding.lookup_by_src()
        elif not binding.incomplete_nat():
            nat_entry = binding.lookup_by_nat()
        elif binding.tunnelonly():
            return self._delete_tunnel(binding.tunnel, reply)
        else:
            return self.generror(reply, 'ERROR: malformed delete request')

        if nat_entry is None:
            return self.generror(reply, 'ERROR: no mapping found')

        self.delnat(nat_entry)
        return self.genok(reply)

    def _delete_tunnel(self, remote, reply):
        """delete tunnel *remote* in aftr, then forget it and its bindings"""
        tunnel = CONFTUN.gettunnel(remote, None)
        if tunnel is None:
            return self.generror(reply, 'ERROR: no tunnel found')
        resp = AFTRSOCK.askdeletetunnel(tunnel.remote)
        if resp != 'OK':
            return self.generror(reply, 'ERROR: ' + resp)
        while tunnel.entries:
            nat_entry = tunnel.entries.pop()
            del CONFNAT[nat_entry.hashkey_src()]
            del CONFNAT[nat_entry.hashkey_nat()]
        del CONFTUN[remote]
        # aftr already deleted the tunnel, so no reload is requested.
        # NOTE(review): the config file is not rewritten here, so the
        # tunnel's lines remain until the next rewrite() -- confirm intended.
        self.genok(reply)
        return False  # don't need to reload

    def rpc_parse_get(self, entry, reply):
        """ parse <get> element

        With a 'tunnel' attribute, reports that tunnel's bindings;
        otherwise reports the whole binding table, inside a <conf> element.
        Never requests a reload.
        """
        remote = canonv6(entry.get('tunnel'))
        if remote is not None:
            # if the optional tunnel attribute is present, get that
            # tunnel's bindings
            tunnel = CONFTUN.gettunnel(remote, None)
            if tunnel is None:
                return self.generror(reply, 'ERROR: IPv6 address not found')
            msg = etree.SubElement(reply, 'conf')
            tunnel.toxml(msg)
        else:
            # else get the whole binding table
            msg = etree.SubElement(reply, 'conf')
            CONFTUN.toxml(msg)
        return False

    def rpc_parse_flush(self, reply):
        """ parse <flush> element: drop every tunnel and binding """
        for tunnel in list(CONFTUN.values()):
            # aftr's response is not checked
            AFTRSOCK.askdeletetunnel(tunnel.remote)
        CONFTUN.clear()
        CONFNAT.clear()
        # remove all nat and tunnel entries from the config file
        CONFFILE.truncate()
        self.genok(reply)
        return False  # don't need to reload

    def addnat(self, binding):
        """record *binding* in memory and append its 'nat' line to the file"""
        parent = CONFTUN.gettunnel(binding.tunnel, binding.naddr)
        nat_entry = NatEntry(parent, binding.protocol, binding.saddr,
                             binding.sport, binding.nport)
        parent.entries.append(nat_entry)
        CONFNAT[nat_entry.hashkey_src()] = nat_entry
        CONFNAT[nat_entry.hashkey_nat()] = nat_entry
        CONFFILE.append("nat %s %s %s %s %s %s\n" %
                        (binding.tunnel, binding.protocol, binding.saddr,
                         binding.sport, binding.naddr, binding.nport))

    def delnat(self, nat_entry):
        """drop *nat_entry* from the in-memory tables

        The config file is not touched here; parse() regenerates it with
        CONFFILE.rewrite() after aftr accepted the reload.
        NOTE(review): the reload is requested before that rewrite, so aftr
        may still read the deleted line -- confirm against askreload().
        """
        nat_entry.parent.entries.remove(nat_entry)
        del CONFNAT[nat_entry.hashkey_src()]
        del CONFNAT[nat_entry.hashkey_nat()]

    def generror(self, reply, error):
        """append <rpc-error>error</rpc-error> to reply; returns False"""
        msg = etree.SubElement(reply, 'rpc-error')
        msg.text = error
        return False  # don't need to reload

    def genok(self, reply):
        """append <ok/> to reply; returns True"""
        etree.SubElement(reply, 'ok')
        return True  # need to reload

def consistency():
    """consistency check, for debugging

    Cross-checks CONFTUN (tunnels keyed by remote address) against CONFNAT
    (NatEntry objects stored under both hashkey_src() and hashkey_nat())
    and prints one diagnostic per mismatch.  Returns nothing.
    """
    for key, tunnel in CONFTUN.items():
        if key != tunnel.remote:
            print('CONFTUN: key %s value %s' % (key, tunnel.remote))
        for entry in tunnel.entries:
            xchk = CONFNAT.get((tunnel.remote, entry.saddr,
                                entry.sport, entry.protocol))
            if xchk is None:
                print('hashkey_src: not found: %s %s %s %s' %
                      (tunnel.remote, entry.saddr, entry.sport,
                       entry.protocol))
            elif xchk is not entry:
                print('hashkey_src: match error:')
                xchk.totext(sys.stdout)
                entry.totext(sys.stdout)
            xchk = CONFNAT.get((tunnel.addr, entry.nport, entry.protocol))
            if xchk is None:
                print('hashkey_nat: not found: %s %s %s' %
                      (tunnel.addr, entry.nport, entry.protocol))
            elif xchk is not entry:
                print('hashkey_nat: match error:')
                xchk.totext(sys.stdout)
                entry.totext(sys.stdout)

    for key, value in CONFNAT.items():
        keystr = ','.join(key)
        # a 4-tuple is a source key, a 3-tuple a nat key; each entry must
        # also be reachable through the other index
        if len(key) == 4:
            other_key, other_name = value.hashkey_nat(), 'hashkey_nat'
        else:
            other_key, other_name = value.hashkey_src(), 'hashkey_src'
        if not any(entry is value for entry in value.parent.entries):
            print('tunnel: not found: %s' % keystr)
        xchk = CONFNAT.get(other_key)
        if xchk is None:
            print('%s: not found: %s' % (other_name, keystr))
        elif xchk is not value:
            print('%s: match error: %s' % (other_name, keystr))

################ AFTR control connection ################

def debug(msg):
    """debug helper: hook for tracing traffic with the AFTR daemon

    Messages are currently discarded; print *msg* here to see them.
    """
    return None

class AftrSock(object):
    """open a control connection to the running aftr

    Commands are newline-terminated text lines sent over TCP to
    127.0.0.1:1015.  asktunnel, askdeletetunnel and askreload report
    socket failures by returning the error text rather than raising.
    """

    def __init__(self):
        """connect to the aftr control port and configure the session

        Prints the socket error and exits the process if the connection
        cannot be established.
        """
        self.sock = socket.socket(socket.AF_INET,
                                  socket.SOCK_STREAM,
                                  socket.IPPROTO_TCP)
        try:
            self.sock.connect(('127.0.0.1', 1015))
        except socket.error as err:
            sys.stdout.write('ERROR: %s\n' % (self._errmsg(err),))
            sys.exit(1)
        self.sock.sendall('session log off\n')
        self.sock.sendall('session config on\n')
        self.sock.sendall('session name aftrconf\n')
        # counter used to build unique 'echo N' end-of-reply markers
        self.transid = 0
        # data received from the aftr but not yet returned as a line
        self.nextlinebuf = ''

    @staticmethod
    def _errmsg(err):
        """return the message text of a socket error

        socket.error normally carries (errno, message); errors built with a
        single argument (e.g. socket.timeout) fall back to str(err) instead
        of raising IndexError.
        """
        args = getattr(err, 'args', ())
        if len(args) >= 2:
            return args[1]
        return str(err)

    def ask(self, cmd):
        """send one command line to the aftr

        Raises socket.error if the send fails (the ask* helpers in this
        class catch it and return the error text).
        """
        self.sock.sendall(cmd + '\n')
        debug('sent to AFTR: %s' % cmd)

    def getnextline(self):
        """get next line from AFTR daemon

        Returns the next buffered line without its trailing newline,
        reading up to 1024 more bytes when the buffer is empty.  If the
        buffered data holds no newline it is returned as-is: a partial
        line, or '' when recv() returned nothing (peer closed).
        """
        if not self.nextlinebuf:
            self.nextlinebuf = self.sock.recv(1024)
        if '\n' not in self.nextlinebuf:
            ret = self.nextlinebuf
            self.nextlinebuf = ''
            return ret
        ret, self.nextlinebuf = self.nextlinebuf.split('\n', 1)
        debug('received from AFTR: %s' % ret)
        return ret

    def expect(self, text=None):
        """expect a text from AFTR daemon

        Sends an 'echo N' marker (N = a fresh transaction id) so the end of
        the reply can be recognized, then performs at most 5 line reads.
        Returns:
          (True, rest)       a line starting with text; rest = remainder
          (False, prevline)  'command failed'; prevline = the line before it
          (False, '')        the echo marker came back, or nothing decided
          (False, errtext)   sending the marker failed
        When text is None no line counts as a match, so only a failure or
        the echo marker ends the wait.
        """
        self.transid += 1
        echo = 'echo ' + str(self.transid)
        try:
            # sendall: send() may transmit only part of the marker
            self.sock.sendall(echo + '\n')
        except socket.error as err:
            return (False, self._errmsg(err))
        prev = got = ''
        for _ in range(5):
            prev, got = got, self.getnextline()
            if not got:
                continue
            if text is not None and got.startswith(text):
                debug("got expected '%s' from AFTR" % text)
                return (True, got[len(text):])
            if got == 'command failed':
                debug('got failure from AFTR')
                return (False, prev)
            if got == echo:
                debug('got echo from AFTR')
                return (False, '')
        return (False, '')

    def asktunnel(self, tunnel, naddr):
        """ask aftr for tunnel natted address

        First asks for the tunnel with the requested address naddr and
        returns naddr if the aftr confirms it.  Otherwise asks for the
        tunnel alone and returns whatever follows 'tunnel <tunnel> ' in the
        reply (presumably the address the aftr chose); on failure this is
        the line preceding 'command failed', or ''.  Socket errors are
        returned as their message text.
        """
        try:
            self.ask('try tunnel ' + tunnel + ' ' + naddr)
        except socket.error as err:
            return self._errmsg(err)
        (ret, _) = self.expect('tunnel ' + tunnel + ' ' + naddr)
        if ret:
            return naddr
        try:
            self.ask('try tunnel ' + tunnel)
        except socket.error as err:
            return self._errmsg(err)
        (ret, nsrc) = self.expect('tunnel ' + tunnel + ' ')
        if ret:
            debug('tunnel %s %s' % (tunnel, nsrc))
        else:
            debug('tunnel %s failed?' % tunnel)
        return nsrc

    def askdeletetunnel(self, tunnel):
        """ask aftr to delete a tunnel

        Returns 'OK' unless the aftr answers 'command failed', in which case
        the line preceding it is returned; socket errors are returned as
        their message text.
        """
        try:
            self.ask('delete tunnel ' + tunnel)
        except socket.error as err:
            return self._errmsg(err)
        # No text to match: expect('') would accept the first non-empty
        # line, even an error message, and report the failure as 'OK'.
        (_, resp) = self.expect(None)
        return resp if resp else 'OK'

    def askreload(self, delay=1):
        """ask aftr to reload config file

        delay is the pause in seconds between sending 'reload' and the echo
        marker.  The pause lets the aftr start the reload before the next
        command arrives, so 'echo' does not end up in the same command
        buffer.  Retries while the aftr reports a reload 'in progress'.
        Returns 'OK', the error line preceding 'reload failed', or a socket
        error message.
        """
        while True:
            try:
                self.sock.sendall('reload\n')
                if delay > 0:
                    time.sleep(delay)
                self.transid += 1
                echo = 'echo ' + str(self.transid)
                self.sock.sendall(echo + '\n')
                # Accumulate everything up to the echo marker: the status
                # text may arrive in an earlier chunk than the echo.
                response = ''
                while echo not in response:
                    chunk = self.sock.recv(1024)
                    if not chunk:
                        # peer closed; recv() would return '' forever
                        return 'connection to aftr closed'
                    response += chunk
            except socket.error as err:
                return self._errmsg(err)
            if 'in progress' not in response:
                break
            # transient error, try again (loop rather than recursion)
        if 'reload failed' in response:
            lines = response.splitlines()
            for i, line in enumerate(lines):
                if 'reload failed' in line:
                    # actual error is the line before 'reload failed'
                    return lines[i - 1] if i > 0 else line
        return 'OK'

# Single shared control connection to the AFTR daemon.  It is created at
# import time: AftrSock() connects to 127.0.0.1:1015 and exits the process
# if the daemon cannot be reached, so the aftr must already be running.
AFTRSOCK = AftrSock()

################ main ################

# default listening ports, used by main() when no -p option is given
sockport, httpport = 4146, 4148

# the only peer address allowed to send requests ("" = any); set by -r
authpeer = ""

if transport == "socket":
    import SocketServer

    class sockhandler(SocketServer.BaseRequestHandler):
        """serve one RPC request over a plain TCP connection"""

        def handle(self):
            """read a request ending in </rpc>, parse it, send the reply"""
            if authpeer != "":
                # acl check
                peer = self.request.getpeername()
                if peer[0] != authpeer:
                    # XXX send error message
                    return

            # Most requests should be under 1k, but we have to be able
            # to accommodate large batched requests. So keep reading until
            # we see the </rpc> end tag. But also build in a timeout so
            # we don't wedge on a malformed request.
            self.request.settimeout(3.0)
            request = ''
            while "</rpc>" not in request[-10:]:
                try:
                    buf = self.request.recv(1024)
                except socket.timeout:
                    break
                if not buf:
                    # the peer closed its side; recv() would keep returning
                    # '' immediately, so waiting longer would spin forever
                    break
                request += buf
            if not request:
                # nothing was received: there is no request to answer
                return
            rpcparser = RpcParser()
            reply = rpcparser.parse(request)
            # send an xml header for the hell of it
            self.request.sendall("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
            self.request.sendall(reply)

elif transport == "http":
    import BaseHTTPServer

    class httphandler(BaseHTTPServer.BaseHTTPRequestHandler):
        """serve RPC requests carried in HTTP request bodies"""

        def do_POST(self):
            """read the request body, parse it, return the XML reply"""
            if authpeer != "" and self.client_address[0] != authpeer:
                # acl check
                self.send_error(403, 'Forbidden')
                return

            # Read exactly Content-Length bytes.  A bare rfile.read() waits
            # for EOF, which a client that keeps the connection open while
            # waiting for our reply never sends.  Per HTTP, a request with
            # no Content-Length (and no transfer coding) has an empty body.
            try:
                length = int(self.headers.get('Content-Length', 0))
            except ValueError:
                self.send_error(400, 'Bad Content-Length')
                return
            request = self.rfile.read(length) if length > 0 else ''
            rpcparser = RpcParser()
            reply = rpcparser.parse(request)

            self.send_response(200)
            self.send_header("Content-type", "application/xml")
            self.end_headers()
            self.wfile.write(reply)

        def do_GET(self):
            """treat GET like POST, just in case"""
            return self.do_POST()

else:
    # debug: read from a file, or pipe from client.py
    def handler(infile=None, outfile=None):
        """parse one request and write the reply

        infile defaults to sys.stdin and outfile to sys.stdout.
        """
        if infile is None:
            infile = sys.stdin
        if outfile is None:
            outfile = sys.stdout
        request = infile.read()
        rpcparser = RpcParser()
        reply = rpcparser.parse(request)
        outfile.write(reply)

def main(args):
    """main: parse the command line, load the config, run the server

    args is the full argument vector, program name first.  Options:
      -l addr   listening address (default "", i.e. any)
      -p port   listening port (default sockport or httpport)
      -r addr   only accept requests from this peer address
      -c file   config file (default aftr.conf)
    Prints usage and re-raises getopt.GetoptError on bad options; exits
    with status 1 on extra positional arguments.
    """
    global authpeer, CONFFILE

    listen_addr = ""
    listen_port = 0
    conf_path = "aftr.conf"

    try:
        options, leftover = getopt.getopt(args[1:], 'p:l:r:c:')
    except getopt.GetoptError:
        sys.stdout.write(' '.join([
            'usage:', str(args[0]),
            '[-l listening addr] [-p listening port]',
            '[-r remote addr] [-c config file]']) + '\n')
        raise

    for flag, value in options:
        if flag == '-p':
            listen_port = int(value)
        elif flag == '-l':
            listen_addr = value
        elif flag == '-r':
            authpeer = value
        elif flag == '-c':
            conf_path = value

    if leftover:
        sys.stdout.write(' '.join([
            args[0] + ':', 'extra arguments:', str(leftover[0]),
            ', ...']) + '\n')
        sys.exit(1)

    CONFFILE = ConfigFile(conf_path)

    if transport == "socket":
        rpcserver = SocketServer.TCPServer(
            (listen_addr, listen_port or sockport), sockhandler)
    elif transport == "http":
        rpcserver = BaseHTTPServer.HTTPServer(
            (listen_addr, listen_port or httpport), httphandler)
    else:
        # debug: read one request from a file, or pipe from client.py
        handler()
        return

    try:
        rpcserver.serve_forever()
    except KeyboardInterrupt:
        pass

# Run the server only when executed as a script, so importing this module
# (e.g. for testing) does not start serving.
if __name__ == "__main__":
    main(sys.argv)
