import sys, Pyro.core, Pyro.util, Pyro.naming
from threading import *
from socket import *

default_ns_host = 'ece.olin.edu'

class RemoteObject(Pyro.core.ObjBase):
    """A Pyro object that advertises itself on a name server and serves requests.

    Construction locates the name server, starts a Pyro daemon bound to this
    machine's IP address (as seen on the route toward the name server), and
    connects/advertises the object under ``name``.  Requests are then served
    either in the calling thread (``requestLoop``) or in a background thread
    (``threadLoop`` + ``stopLoop`` + ``join``).
    """

    def __init__(self, name=None, ns_host=default_ns_host):
        """Register this object with the Pyro name server on ``ns_host``.

        name: name to advertise under; defaults to ``str(self)``.
        ns_host: host running the Pyro name server.
        """
        Pyro.core.ObjBase.__init__(self)
        if name is None:
            name = str(self)

        # find the name server
        self.ns_host = ns_host
        self.ns = Pyro.naming.NameServerLocator().getNS(ns_host)

        # create the daemon on the address other hosts can reach us at
        addr = self.get_ip_addr()
        self.demon = Pyro.core.Daemon(host=addr)
        self.demon.useNameServer(self.ns)

        # flag polled by the background loop; cleared by stopLoop()
        self.running = 0

        # instantiate the object and advertise it
        self.uri = self.demon.connect(self, name)

    def requestLoop(self):
        """Serve requests in the calling thread until an exception occurs.

        On any exception the object is disconnected and the daemon shut
        down.  KeyboardInterrupt (Ctrl-C) is treated as a normal stop and
        swallowed; any other exception is re-raised after cleanup.
        """
        try:
            self.demon.requestLoop()
        except KeyboardInterrupt:
            print('Shutting down the server...')
            self.cleanup()
        except:
            print('Shutting down the server...')
            self.cleanup()
            # A bare raise re-raises the active exception with its original
            # traceback; ``raise sys.exc_type, sys.exc_value`` dropped the
            # traceback and read process-global (non-thread-safe) state.
            raise

    def cleanup(self):
        """Disconnect this object from the daemon and shut the daemon down."""
        # remove the name from the name server
        self.demon.disconnect(self)
        self.demon.shutdown()

    def threadLoop(self):
        """Serve requests in a background thread; stop it with stopLoop()."""
        # Raise the flag *before* starting the thread.  If the new thread
        # set it itself, a stopLoop() issued right after threadLoop()
        # returns could be overwritten and the loop would never end.
        self.running = 1
        self.thread = Thread(target=self._pollLoop)
        self.thread.start()

    def stoppableLoop(self):
        """Serve requests in the calling thread until stopLoop() is called."""
        self.running = 1
        self._pollLoop()

    def _pollLoop(self):
        # Handle requests with a short timeout so that a cleared flag is
        # noticed promptly; always disconnect and shut down on exit.
        try:
            while self.running:
                self.demon.handleRequests(0.1)
        finally:
            self.cleanup()

    def stopLoop(self):
        """Ask a running stoppableLoop/threadLoop to finish (non-blocking)."""
        self.running = 0

    def join(self):
        """Block until the background thread started by threadLoop() exits."""
        self.thread.join()

    def get_ip_addr(self):
        """Return this machine's IP address on the route toward ``ns_host``.

        Connecting a UDP socket sends no packets; it only makes the OS choose
        the local interface it would use to reach ``ns_host``.  So this works
        even when nothing listens on the port, unlike a TCP connect.  The
        socket is always closed.
        """
        probe = socket(AF_INET, SOCK_DGRAM)
        try:
            probe.connect((self.ns_host, 80))
            return probe.getsockname()[0]
        finally:
            probe.close()

def get_remote_object(name, ns_host=default_ns_host):
    """Return a Pyro proxy for the object registered as ``name``.

    The name is resolved to a URI by the name server on ``ns_host``, and a
    proxy is built from that URI.
    """
    name_server = get_name_server(ns_host)
    return Pyro.core.getProxyForURI(name_server.resolve(name))

def get_name_server(ns_host=default_ns_host):
    """Locate the Pyro name server on ``ns_host`` and return it.

    The returned object is whatever ``NameServerLocator().getNS`` yields;
    callers use it for ``resolve``, ``list``, ``createGroup`` and
    ``unregister``.  (The old comment claimed this listed the objects in a
    group, which it never did.)
    """
    locator = Pyro.naming.NameServerLocator()
    return locator.getNS(ns_host)

class NameServer:
    """Convenience wrapper around a located Pyro name server."""

    def __init__(self, ns_host=default_ns_host):
        """Locate the Pyro name server running on ``ns_host``."""
        self.ns_host = ns_host
        self.ns = Pyro.naming.NameServerLocator().getNS(ns_host)

    def get_remote_object(self, name, ns_host=None):
        """Resolve ``name`` and return a Pyro proxy for the remote object.

        name: registered name to look up.
        ns_host: optional other host whose name server should do the lookup.
            By default this instance's name server is used.  (The method
            previously lacked ``self`` and could not be called at all.)
        """
        if ns_host is None or ns_host == self.ns_host:
            ns = self.ns
        else:
            ns = Pyro.naming.NameServerLocator().getNS(ns_host)
        uri = ns.resolve(name)
        return Pyro.core.getProxyForURI(uri)

    def name_server_query(self, name, group=None):
        """Report what ``name`` is within ``group``.

        Returns 1 if the name is a remote object, 0 if it is a group, and
        -1 if it doesn't exist.  ``group`` is passed straight to ``ns.list``.
        """
        return dict(self.ns.list(group)).get(name, -1)

    def create_group(self, name):
        """Create a group with the given name on the name server."""
        self.ns.createGroup(name)

    def get_remote_object_list(self, prefix='', group=None):
        """Return the names of remote objects (not sub-groups) in ``group``
        that start with ``prefix``."""
        return [s for (s, n) in self.ns.list(group)
                if n == 1 and s.startswith(prefix)]

    def clear(self, prefix='', group=None):
        """Unregister every remote object in ``group`` whose name starts
        with ``prefix``.  Sub-groups are left in place."""
        for (s, n) in self.ns.list(group):
            if n != 1 or not s.startswith(prefix):
                continue
            # ns.list appears to return names relative to ``group`` (main()
            # queries the short name inside a group), so qualify them the
            # same way main() builds full names before unregistering.
            # Previously this passed the type flag ``n`` instead of a name.
            if group:
                s = '%s.%s' % (group, s)
            self.ns.unregister(s)
    
def main(script, name='remote_object', group='test', *args):
    """Start a demo RemoteObject named ``<group>.<name>`` and serve requests.

    script: program name (sys.argv[0]); unused.
    name, group: optional command-line overrides for the object's name and
        the group it is registered in.  Extra arguments are ignored.

    The name-server helpers are methods of NameServer; the old code called
    them as module-level functions and died with a NameError.
    """
    names = NameServer()
    if names.name_server_query(group) == -1:
        print('Making group %s...' % group)
        names.create_group(group)

    full_name = '%s.%s' % (group, name)
    print('Starting %s...' % full_name)

    server = RemoteObject(full_name)

    # show what the name server now reports for the group and the object
    print(names.ns.list(group))

    print('%s %s' % (group, names.name_server_query(group)))
    print('%s %s' % (full_name, names.name_server_query(name, group)))
    print('%s %s' % (full_name, names.name_server_query(full_name)))

    print(names.get_remote_object_list('a'))
    print('%s %s' % (group, names.get_remote_object_list(group=group)))

    server.requestLoop()

if __name__ == '__main__':
    # Forward the command line unchanged: argv[0] fills main's ``script``
    # parameter and any further words override ``name`` and ``group``.
    command_line = list(sys.argv)
    main(*command_line)
