#!/usr/bin/python
# firewaller: OpenVPN learn-address script to install iptables rules
# Version 0.1
# Copyright 2006, Dale Sedivec
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
# USA

# RULES_MAP names the file that maps each client's common name (CN)
# to a rule set name (see RULES_DIRECTORY).  Put one mapping per
# line: the CN, then some whitespace, then the rule set name.
#
# OpenVPN common names may only contain alphanumerics, underscore,
# hyphen, period and '@'; any other character is converted to an
# underscore.  Write each CN in the map in that converted form.
RULES_MAP = "/etc/openvpn/scripts/firewaller-map"

# Directory holding the rule set files: a rule set name taken from
# RULES_MAP is the name of a file in here.  Each line of a rule set
# file holds the arguments of one iptables rule, written without the
# leading "iptables -A <chain name>".  For example,
# "-p tcp --dport ftp -j DROP" is a valid line.
RULES_DIRECTORY = "/etc/openvpn/scripts/firewaller-rules"

# Chain that receives firewaller's "dispatch rules".  firewaller never
# creates this chain, so it must already exist.  A dispatch rule
# matches one client's source address (its MAC address when bridging,
# its IP address when routing) and jumps to the chain built from that
# client's rule set.
# NOTE(review): presumably some built-in chain (e.g. FORWARD) must
# jump to this chain for the dispatch rules to see traffic -- confirm
# against your own firewall setup.
CHAIN_NAME = "openvpn-firewaller"

# Command used to invoke iptables.  Depending on your configuration,
# OpenVPN may not be running as root, hence the sudo prefix.
IPTABLES = "sudo /sbin/iptables"

# IPTABLES_CLIENT_MATCH holds the iptables match arguments that select
# traffic from one OpenVPN client.  The string is %-formatted with a
# dict, so %(device)s expands to the device name (taken from the "dev"
# environment variable) and %(address)s to the client's address.
#
# Bridged network (match on the client's MAC address):
IPTABLES_CLIENT_MATCH = "-m mac --mac-source %(address)s"
# Routed network (untested):
#IPTABLES_CLIENT_MATCH = "-s %(address)s"

# Every client address gets its own chain, named after that address
# (a MAC address when bridging, an IP address when routing).
# REPLACE_EXISTING_CHAIN decides what happens when a chain of that
# name is already present.  When True, firewaller first tries to
# remove the client's matching dispatch rule(s) from CHAIN_NAME, then
# flushes and deletes the old chain before creating it afresh.  When
# False, creating the chain fails and firewaller exits with an error.
#
# Leaving this on is recommended: stale chains can linger in iptables
# (most likely through some fault of firewaller's own).
REPLACE_EXISTING_CHAIN = True

######################################################################

# A potential problem: even if every client uses the same rule set,
# they all get their own chain (with identical rules), thus you might
# have way too many chains.  It might be better to name chains after
# their rule sets, then have the dispatch rule dispatch to the rule
# set's chain.  You could attempt to delete the rule set chain each
# time an address is deleted; it won't succeed until all addresses
# using that rule set have been deleted.
#
# Of course, if you get that many chains, you'll probably also get a
# whole bunch of dispatch rules, which will be a huge performance hit
# (as opposed to just memory usage hit, or possibly hitting some upper
# limit in iptables, as with the previous situation).
#
# Also: some kind of iptables transactional API might be nice, to
# roll back rules already executed as part of the same transaction
# when a later iptables command fails.  I have had the occasional bad
# rule hang around; however, REPLACE_EXISTING_CHAIN=True probably
# removes the need for this.
#
# Potential things to add:
# * Logging, both errors and success messages.
# * Optionally run a script, rather than just reading iptables rules
#   from a file.

import sys
import os

class IptablesError(Exception):
    """Raised when an iptables command cannot be started or fails."""

def iptables(iptablesArguments, raiseOnError=True):
    """Run iptables with the given arguments.

    The command line is IPTABLES followed by iptablesArguments and is
    run through the shell by os.system.  The arguments are not quoted
    or escaped, so they are subject to shell parsing; callers must not
    pass untrusted text.

    Returns True if the command exits with status 0.  On any other
    outcome (a non-zero exit status, or termination by a signal),
    raises IptablesError if raiseOnError is True (the default) and
    returns False otherwise.

    Raises IptablesError regardless of raiseOnError if the command
    could not be started at all (os.system returned -1).
    """
    command = "%s %s" % (IPTABLES, iptablesArguments)
    status = os.system(command)
    if status == -1:
        raise IptablesError("failed to invoke iptables (%s)" % (command,))
    # os.system returns a wait status.  WEXITSTATUS() is only meaningful
    # when WIFEXITED() is true; for a command killed by a signal it reads
    # as 0, which would otherwise be mistaken for success.
    if os.WIFEXITED(status):
        exitStatus = os.WEXITSTATUS(status)
        if exitStatus == 0:
            return True
        message = "iptables exited with status %d (%s)" % (exitStatus,
                                                          command)
    elif os.WIFSIGNALED(status):
        message = "iptables killed by signal %d (%s)" % (
            os.WTERMSIG(status), command)
    else:
        message = "iptables ended with wait status %d (%s)" % (status,
                                                               command)
    if raiseOnError:
        raise IptablesError(message)
    return False

def newTable(name):
    """Create a new, empty iptables chain called name.

    Despite the function's name this creates a chain ("iptables -N"),
    not a table.  Raises IptablesError if iptables fails.
    """
    arguments = "-N " + name
    iptables(arguments)

def deleteTable(name, raiseOnError=True):
    """Flush, then delete, the iptables chain called name.

    The chain is emptied ("-F") before it is deleted ("-X").  When
    raiseOnError is True, a failing "-F" raises IptablesError before
    "-X" is attempted.  When it is False, both commands are attempted
    and their failures are ignored.
    """
    for option in ("-F ", "-X "):
        iptables(option + name, raiseOnError)

def makeDispatchRuleMethod(action, raiseOnError=True):
    """Build a function that adds or removes one client's dispatch rule.

    action is the iptables command letter without its dash (e.g. "A"
    to append, "D" to delete).  The returned function takes (device,
    address) and runs

        iptables -<action> CHAIN_NAME <client match> -j <address>

    i.e. a rule in CHAIN_NAME that jumps to the per-client chain named
    after address.  It returns iptables()'s result and forwards
    raiseOnError to it.
    """
    def runDispatchRule(device, address):
        clientMatch = makeDispatchMatch(device, address)
        arguments = "-%s %s %s -j %s" % (action, CHAIN_NAME, clientMatch,
                                         address)
        return iptables(arguments, raiseOnError)
    return runDispatchRule
# Appending raises on failure.  Deleting returns False instead, so
# callers can treat "no such rule" as a normal outcome.
addDispatchRule = makeDispatchRuleMethod("A")
deleteDispatchRule = makeDispatchRuleMethod("D", raiseOnError=False)

def makeDispatchMatch(device, address):
    """Return the iptables match arguments selecting one client.

    Expands the %(device)s and %(address)s placeholders in
    IPTABLES_CLIENT_MATCH with the given device and address.
    """
    substitutions = dict(device=device, address=address)
    return IPTABLES_CLIENT_MATCH % substitutions

def installRules(address, ruleSetName):
    """Append every rule in rule set ruleSetName to the chain address.

    The rule set is read from RULES_DIRECTORY/ruleSetName.  Each
    remaining line, stripped of surrounding whitespace, is run as
    "iptables -A <address> <line>".  Blank lines and lines starting
    with '#' are skipped.  Otherwise they would become an
    "iptables -A <address>" with an empty rulespec (the shell treats a
    leading '#' as a comment).

    Raises IptablesError on the first rule iptables rejects.  Rules
    appended before it are left in the chain.
    """
    rulesPath = os.path.join(RULES_DIRECTORY, ruleSetName)
    # "with" ensures the rule file is closed even if a rule fails.
    with open(rulesPath, "r") as rules:
        for rule in rules:
            rule = rule.strip()
            if not rule or rule.startswith("#"):
                continue
            iptables("-A %s %s" % (address, rule))

def readRulesMap(mapPath=None):
    """Yield (common name, rule set name) pairs from the rules map.

    mapPath defaults to RULES_MAP.  The file is read lazily.  It is
    closed when iteration finishes or the generator is closed.

    Blank lines and lines starting with '#' are skipped.  A '#' cannot
    begin a valid OpenVPN common name.  Every other line must contain
    exactly two whitespace-separated fields, the CN and the rule set
    name; any other line raises ValueError naming the file and line.
    """
    if mapPath is None:
        mapPath = RULES_MAP
    with open(mapPath, "r") as rulesMap:
        for lineNumber, line in enumerate(rulesMap, 1):
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            fields = line.split()
            if len(fields) != 2:
                raise ValueError("%s:%d: expected a common name and a rule "
                                 "set name, got %r" % (mapPath, lineNumber,
                                                       line))
            yield fields[0], fields[1]

def lookupCommonName(cn):
    """Return the rule set name mapped to common name cn, or None.

    The rules map is scanned in file order.  Scanning stops at the
    first entry whose CN equals cn.
    """
    matches = (ruleSet for mappedCN, ruleSet in readRulesMap()
               if mappedCN == cn)
    return next(matches, None)

def addAddress(device, address, cn):
    """Install the firewall rules for client address with common name cn.

    Looks up cn's rule set.  If cn has no rule set, nothing is done, so
    the client's traffic is not dispatched to any per-client chain.
    Otherwise this creates a chain named after address, appends the
    rule set's rules to it, and adds a dispatch rule in CHAIN_NAME that
    jumps to it.

    When REPLACE_EXISTING_CHAIN is set, any leftover dispatch rules and
    chain for address are removed first, ignoring failures.  Raises
    IptablesError if creating the chain, appending a rule, or adding
    the dispatch rule fails.
    """
    ruleSetName = lookupCommonName(cn)
    if not ruleSetName:
        return
    if REPLACE_EXISTING_CHAIN:
        # "iptables -D" deletes only the first matching rule.  Repeat
        # until it fails, so a duplicated stale dispatch rule cannot keep
        # the old chain referenced.  Otherwise the "-X" below would fail
        # and the "-N" after it would raise.
        while deleteDispatchRule(device, address):
            pass
        deleteTable(address, False)
    newTable(address)
    installRules(address, ruleSetName)
    addDispatchRule(device, address)

def deleteAddress(device, address):
    """Remove client address's dispatch rule and then its chain.

    If no matching dispatch rule was found, the chain is left alone.
    Failure to flush or delete the chain raises IptablesError.
    """
    removed = deleteDispatchRule(device, address)
    if not removed:
        return
    deleteTable(address)

def updateAddress(device, address, cn):
    """Replace whatever is installed for address with cn's rule set.

    Equivalent to deleteAddress(device, address) followed by
    addAddress(device, address, cn).
    """
    deleteAddress(device, address)
    # addAddress takes (device, address, cn).  Swapping the last two
    # would look the address up as a common name, find no rule set, and
    # leave the client with no dispatch rule (i.e. unfiltered).
    addAddress(device, address, cn)

def main():
    """Handle one learn-address invocation and exit.

    Reads the device name from the "dev" environment variable and the
    operation from sys.argv[1].  "add" and "update" receive
    sys.argv[2:4] (address and common name); "delete" receives
    sys.argv[2:3] (address only).  An unknown operation raises
    KeyError, and missing arguments surface as a TypeError from the
    handler.
    """
    device = os.environ["dev"]
    operation = sys.argv[1]
    # operation -> (last argv index the handler uses, handler)
    lastIndex, handler = {
        "add": (3, addAddress),
        "update": (3, updateAddress),
        "delete": (2, deleteAddress),
    }[operation]
    handlerArguments = sys.argv[2:lastIndex + 1]
    handler(device, *handlerArguments)
    # Failures propagate as unhandled exceptions, which make Python exit
    # with a non-zero status.  Reaching this line means success.
    sys.exit(0)

# Run main() only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
