# threaded.py

import unittest as ut
import socket
import threading
import struct
import errno
import time
import random

import dpkt

from dprobelib import sniffer

# Interface passed to sniffer.sniff(); "lo" is presumably the Linux loopback
# device name -- TODO confirm on other platforms.
SNIFFER_INTERFACE = "lo"
# Address the traffic generator binds its UDP server socket to.
SERVER_ADDRESS = "127.0.0.1"
# Maximum time a blocking operation should wait (in seconds).
MAX_WAIT = 3.0

class CooperativeThread (threading.Thread):
    """A Thread that can notify when it has started, and can be stopped.

    The thread's own code calls notifyStarted() once it is ready and
    polls shouldStop(); other threads use startAndWait() and stop().
    Subclasses may override _closeResources() to release whatever the
    thread could be blocked on.

    """

    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)
        # Set by notifyStarted(); waited on by startAndWait().
        self.__started = threading.Event()
        # Set by stop(); reported by shouldStop().
        self.__shouldStop = threading.Event()

    def notifyStarted(self):
        '''Should be called by the thread when it has "started."'''
        self.__started.set()

    def shouldStop(self):
        """Returns True if someone has requested that this thread terminate."""
        return self.__shouldStop.isSet()

    def startAndWait(self, timeout=None):
        """Start the thread and wait for it to signal that it has started.

        Waits at most timeout seconds (forever if None).  Returns True if
        notifyStarted() has been called, False if the wait timed out.

        """
        self.start()
        self.__started.wait(timeout)
        return self.__started.isSet()

    def _closeResources(self):
        """Close resources that this thread might be blocking on.

        This is intended to be overridden by a subclass, if necessary.
        stop() will call this after ensuring that shouldStop() is True,
        whether or not the thread is running -- it may never have been
        started, or may already have exited -- so an override should
        tolerate resources that were never created or are already closed.
        Any Exception raised here is ignored by stop().

        """
        pass

    def stop(self, timeout=None):
        """Request that the thread stop and wait for it to join.

        Sets the stop flag, calls _closeResources(), then, if the thread
        is running, joins it for at most timeout seconds (forever if
        None).  Returns True if the thread is still alive afterwards.

        """
        self.__shouldStop.set()
        # Close resources even when the thread is not alive: it may never
        # have been started, or may have exited on its own (possibly right
        # after the flag above was set), and its resources would otherwise
        # never be released.
        try:
            self._closeResources()
        except Exception:
            # Best effort: resources may be missing or already closed.
            pass
        if self.__isAlive():
            self.join(timeout)
        return self.__isAlive()

    def __isAlive(self):
        """Return True if the thread is running, on any Python version."""
        # Thread.isAlive() was removed in Python 3.9; Thread.is_alive()
        # exists from Python 2.6 on.  Use whichever this Python provides.
        isAlive = getattr(self, "is_alive", None) or self.isAlive
        return isAlive()

    # This method is necessary to work around bug #1171023.
    def join(self, timeout=None):
        """Wait at most timeout seconds (forever if None) for the thread.

        Works around Python bug #1171023: on affected versions, an
        exception such as KeyboardInterrupt raised while Thread.join()
        waits can leave the thread's internal lock (_Thread__block) held
        -- NOTE(review): confirm against the bug report.  The lock is
        released only when Thread.join() raised.  After a normal return,
        Thread.join() has already released it, and releasing it again
        could release the lock while another thread holds it.

        """
        joined = False
        try:
            threading.Thread.join(self, timeout)
            joined = True
        finally:
            if not joined:
                try:
                    self._Thread__block.release()
                except Exception:
                    # Lock not held, or no such attribute on this Python.
                    pass

# XXX It's possible this doesn't deserve to be a subclass, since it
# seems to only define a useful method, as opposed to really being a
# type of CooperativeThread.  It is convenient to call
# self.shouldStop() though.
class CooperativeIOThread (CooperativeThread):
    """A CooperativeThread that tolerates I/O failures caused by stop()."""

    # XXX Wish I had a better name for this routine.
    def callIOSafe(self, callable, *args, **kwargs):
        """Run callable and ignore errors caused by resource deprivation.

        An error caused by resource deprivation is defined as a
        socket exception whose error code is EBADF (a descriptor that is
        no longer valid, e.g. closed by _closeResources()) raised when
        shouldStop() == True.  Any other exception propagates unchanged.

        Returns whatever callable returns, or None when an error was
        ignored.  (The parameter name shadows the builtin callable() but
        is kept because callers may pass it by keyword.)

        """
        try:
            return callable(*args, **kwargs)
        except socket.error:
            # Fetch the exception via sys.exc_info() so this block is valid
            # under both Python 2 and Python 3 exception syntax.
            import sys
            exception = sys.exc_info()[1]
            if self.__isClosedSocketError(exception) and self.shouldStop():
                return None
            raise

    def __isClosedSocketError(self, exception):
        """Return True if exception carries the EBADF error code."""
        # An exception raised without arguments has empty args; indexing
        # it would replace the original error with an IndexError.
        args = getattr(exception, "args", ())
        return len(args) > 0 and args[0] == errno.EBADF

class TrafficGenerator (CooperativeIOThread):
    """Thread that repeatedly sends data over UDP to a local server socket.

    The server socket is bound in the constructor, so host and port are
    available before the thread starts.  Nothing in this class reads from
    it; it only serves as the destination of the generated datagrams.

    """
    # Seconds to sleep before each send (a period, despite the name).
    SEND_FREQUENCY = 0.1

    def __init__(self, data, *args, **kwargs):
        CooperativeIOThread.__init__(self, *args, **kwargs)
        self.data = data
        # Created by run(); None until then so _closeResources() can tell.
        self.__client = None
        self.__makeServerSocket()

    def __makeServerSocket(self):
        self.__server = self.__makeUDPSocket()
        # Port 0 lets the OS choose a free port; getsockname() reports it.
        self.__server.bind((SERVER_ADDRESS, 0))
        self.host, self.port = self.__server.getsockname()

    def __makeUDPSocket(self):
        return socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def _closeResources(self):
        """Close the client socket (if created), then the server socket."""
        # Close the client first: a send on the closed client fails with
        # EBADF, which callIOSafe() ignores once shouldStop() is True.
        # Closing the server first could let a send reach a port with no
        # listener, after which a connected UDP socket may report
        # ECONNREFUSED, which callIOSafe() does not ignore.
        if self.__client is not None:
            self.__client.close()
        self.__server.close()

    def run(self):
        """Send self.data every SEND_FREQUENCY seconds until stopped."""
        self.__makeClientSocket()
        self.notifyStarted()
        try:
            while not self.shouldStop():
                time.sleep(self.SEND_FREQUENCY)
                self.callIOSafe(self.__client.send, self.data)
        finally:
            # Release the client even if stop() ran before it existed;
            # closing an already-closed socket is harmless.
            self.__client.close()

    def __makeClientSocket(self):
        self.__client = self.__makeUDPSocket()
        self.__client.connect((self.host, self.port))

class SnifferThread (CooperativeIOThread):
    """Thread that sniffs SNIFFER_INTERFACE until it sees a test packet.

    A test packet is an IP/UDP packet addressed to destinationHost and
    destinationPort, carrying exactly data as its UDP payload unless data
    is None.  After run() finishes, sawTestPacket says whether one was
    seen and lastPacket holds the last packet read from the sniffer.

    """
    def __init__(self, destinationHost, destinationPort, data, *args,
                 **kwargs):
        """If data is None, no data matching will be performed."""
        CooperativeIOThread.__init__(self, *args, **kwargs)
        # Packed 4-byte form, as compared against dpkt's IP.dst below.
        self.__host = socket.inet_aton(destinationHost)
        self.__port = destinationPort
        self.__data = data
        # Created by run(); None until then so _closeResources() can tell.
        self.__sniffer = None
        self.lastPacket = None
        self.sawTestPacket = False

    def run(self):
        self.__sniffer = sniffer.sniff(SNIFFER_INTERFACE)
        def snifferLoop():
            for outgoing, packet in self.__sniffer:
                self.lastPacket = packet
                if self.__isTestPacket(packet):
                    self.sawTestPacket = True
                    break
                elif self.shouldStop():
                    break
        # Announce readiness only after the sniffer has been created.
        self.notifyStarted()
        # stop() closes the sniffer; an EBADF error that causes is ignored
        # by callIOSafe() once shouldStop() is True.
        self.callIOSafe(snifferLoop)

    def __isTestPacket(self, packet):
        """Return True if packet is the UDP packet this thread waits for.

        packet is parsed as an IP packet (assumes the sniffer yields
        packets without a link-layer header -- TODO confirm).  Packets
        too short or malformed to parse are not test packets.

        """
        try:
            ipPacket = dpkt.ip.IP(packet)
        except dpkt.UnpackError:
            return False
        if ipPacket.dst != self.__host or ipPacket.p != socket.IPPROTO_UDP:
            return False
        udpPacket = ipPacket.data
        # dpkt may leave the payload as raw bytes when it cannot parse a
        # UDP header (e.g. a truncated capture); such packets never match.
        if not isinstance(udpPacket, dpkt.udp.UDP):
            return False
        if udpPacket.dport != self.__port:
            return False
        if self.__data is not None:
            return udpPacket.data == self.__data
        return True

    def _closeResources(self):
        """Close the sniffer, if run() has created it."""
        if self.__sniffer is not None:
            self.__sniffer.close()

class TestSniffer (ut.TestCase):
    """End-to-end tests: send UDP traffic on the loopback and sniff it."""

    def testSniffPacket(self):
        snifferThread = self.__runSniffer()
        self.__assertSnifferSawTestPacket(snifferThread)

    def __assertSnifferSawTestPacket(self, snifferThread):
        self.assertTrue(snifferThread.sawTestPacket,
                        "sniffer never saw test packet")

    def __runSniffer(self, dataLength=None, snifferChecksData=True):
        """Run a sniffer against a traffic generator; return the sniffer.

        dataLength is the payload size (None selects __makeRandomData's
        default).  When snifferChecksData is False the sniffer matches on
        destination only, not on payload.  Both threads are stopped
        before returning, even if an assertion fails.

        """
        data = self.__makeRandomData(dataLength)
        snifferThread, trafficThread = self.__makeSnifferThreads(
            data, snifferChecksData)
        try:
            # Start the sniffer first so it is set up before any test
            # packet is sent.
            self.__assertThreadStarted(snifferThread, "sniffer")
            self.__assertThreadStarted(trafficThread, "traffic generator")
            # The sniffer thread exits by itself once it sees the packet.
            snifferThread.join(MAX_WAIT)
        finally:
            snifferThread.stop(MAX_WAIT)
            trafficThread.stop(MAX_WAIT)
        return snifferThread

    def __makeSnifferThreads(self, data, snifferChecksData):
        """Create (but do not start) the sniffer and traffic threads.

        The sniffer targets the traffic generator's server address and
        compares payloads only when snifferChecksData is True.

        """
        trafficThread = TrafficGenerator(data=data)
        if snifferChecksData:
            snifferData = data
        else:
            snifferData = None
        snifferThread = SnifferThread(destinationHost=trafficThread.host,
                                      destinationPort=trafficThread.port,
                                      data=snifferData)
        return snifferThread, trafficThread

    def __assertThreadStarted(self, thread, threadName):
        self.assertTrue(thread.startAndWait(MAX_WAIT),
                        "%s never started" % (threadName,))

    def __makeRandomData(self, length=None):
        """Return a string of length random bytes (64 if length is None)."""
        if length is None:
            length = 64
        values = [random.randint(0, 255) for _ in range(length)]
        # A single pack() yields str on Python 2 and bytes on Python 3,
        # whereas "".join() over per-byte packs fails on Python 3.
        return struct.pack("%dB" % length, *values)

    def testCorrectCaptureLength(self):
        """A packet longer than CAPTURE_LENGTH is captured at that length."""
        dataLength = sniffer.CAPTURE_LENGTH * 2
        # Payload matching is disabled: the captured packet is truncated,
        # so its payload presumably cannot equal the full data.
        snifferThread = self.__runSniffer(dataLength, snifferChecksData=False)
        self.__assertSnifferSawTestPacket(snifferThread)
        self.assertEqual(sniffer.CAPTURE_LENGTH, len(snifferThread.lastPacket))

# Generated by GNU enscript 1.6.1.