"""
lexx.py
Allen B. Downey

Implementation of deterministic and nondeterministic
finite automata.

Page references below are from Sipser, "Introduction to the
Theory of Computation: Second Edition"

"""

def id_generator():
    """Generate ids for new objects: yields 1, 2, 3, ... without end."""
    current = 0
    while True:
        current += 1
        yield current


class State:
    """represent a state in an FA; every State gets a unique integer
    id (1, 2, 3, ...) and a printable name derived from it"""

    # one generator shared by all instances, so ids are unique across
    # every automaton built in this process
    ids = id_generator()

    def __init__(self):
        # the next() builtin works on Python 2.6+ and Python 3, unlike
        # the Python 2-only generator method .next()
        self.id = next(State.ids)
        self.name = 'State %d' % (self.id,)

    def __str__(self):
        return self.name


class DFA:
    """deterministic finite automaton

    Constructor arguments, stored as attributes of the same name:
      states   -- set of states
      alphabet -- set of input symbols
      trans    -- dict mapping (state, symbol) -> next state
      start    -- the start state
      accept   -- set of accept states
    """

    def __init__(self, states, alphabet, trans, start, accept):
        self.states = states
        self.alphabet = alphabet
        self.trans = trans
        self.start = start
        self.accept = accept

        # the pit of despair is not part of the formal
        # definition of a DFA, but it is useful for my
        # implementation of append.  make_dfa sets it to a
        # non-accepting state whose transitions all loop back
        # to itself; None means no pit is known.
        self.pit = None

    def dump(self):
        """print the states and transitions:
        the start state is marked with an S, accept states with a *"""
        for state in self.states:
            marks = ''
            if state is self.start:
                marks += 'S '
            if state in self.accept:
                marks += '* '
            # a single print call gives the same output on Python 2 and 3
            print(marks + str(state))

        for (state, symbol), dest in self.trans.items():
            # str() rather than .name so a None destination (no pit)
            # is still printable
            print(str(state) + ', ' + symbol + ' -> ' + str(dest))

    def copy(self, constructor):
        """make a copy of this DFA, returning either a DFA or NFA,
        depending on constructor.  The sets (states and accept)
        and dictionary (trans) have to be copied, but the states
        themselves can be shared (State objects are never modified
        after creation) and should be shared (so that different
        trans functions can refer to the same states).
        """
        states = self.states.copy()
        alphabet = self.alphabet
        trans = self.trans.copy()
        start = self.start
        accept = self.accept.copy()
        new = constructor(states, alphabet, trans, start, accept)
        new.pit = self.pit
        return new

    def append(self, next):
        """return a new DFA that accepts the same language as the
        original with each string extended by the symbol (next).

        Assumes every transition out of an accept state leads to the
        pit (true for DFAs built by make_dfa): those transitions are
        overwritten.  self.pit should be set; otherwise the new
        transitions lead to None.  If (next) is not in the alphabet,
        the new accept state is unreachable and nothing is accepted.
        """

        # make a copy with no accept states
        new = self.copy(DFA)
        new.accept = set()

        # create a new state and make it an accept state
        new_state = State()
        new.states.add(new_state)
        new.accept.add(new_state)

        # every transition out of the new state goes to the pit;
        # done once, and even when there are no old accept states
        for symbol in self.alphabet:
            new.trans[new_state, symbol] = new.pit

        # from each old accept state: the next symbol leads to the
        # new accept state; anything else goes to the pit of despair!
        for old_accept in self.accept:
            for symbol in self.alphabet:
                if symbol == next:
                    new.trans[old_accept, symbol] = new_state
                else:
                    new.trans[old_accept, symbol] = new.pit

        return new

    def union(self, other):
        """return a new DFA that accepts the union of the languages
        accepted by self and other.  See page 46 of Sipser.

        Both DFAs need a transition for every symbol of the combined
        alphabet from every state; otherwise KeyError is raised.
        """

        # the set of states is the cross product of states from self
        # and other.  (states) is a dictionary that maps from
        # (r1, r2) -> q, where r1 is an element of self.states,
        # r2 is an element of other.states, and q is a new state.
        states = {}
        for r1 in self.states:
            for r2 in other.states:
                states[r1, r2] = State()

        # compute the union of the alphabets
        alphabet = self.alphabet.union(other.alphabet)

        # form the new transition function and the accept states:
        # a pair accepts if either of its components accepts
        trans = dict()
        accept = set()
        for (r1, r2), state in states.items():
            for a in alphabet:
                trans[state, a] = states[self.trans[r1, a], other.trans[r2, a]]
            if r1 in self.accept or r2 in other.accept:
                accept.add(state)

        # start in the state that represents (self.start, other.start)
        start = states[self.start, other.start]

        # the states must be a set (not a list of dict values) so that
        # copy() and append() work on the result
        new = DFA(set(states.values()), alphabet, trans, start, accept)

        # if both pits are genuine (non-accepting, self-looping), their
        # pair is too; None when either operand has no pit
        new.pit = states.get((self.pit, other.pit))
        return new

    def process(self, string):
        """read a string and print 'accept' if the FA accepts the
        string and 'reject' otherwise.

        Returns True if the string was accepted, False otherwise.
        Raises KeyError if a symbol has no transition from the
        current state.
        """
        state = self.start
        for symbol in string:
            state = self.trans[state, symbol]

        # if you end in an accept state, accept.
        accepted = state in self.accept
        print('accept' if accepted else 'reject')
        return accepted

        

def _merge_nulltrans(*automata):
    """return a new nulltrans mapping holding the null transitions of
    all the given automata (plain DFAs, which have no nulltrans,
    contribute none).  The lists are fresh, so later appends through
    add_nulltrans never modify the inputs."""
    merged = dict()
    for fa in automata:
        for src, dests in getattr(fa, 'nulltrans', {}).items():
            merged.setdefault(src, []).extend(dests)
    return merged


class NFA(DFA):
    """non-deterministic finite automaton"""

    def __init__(self, *args):
        """an NFA is just a DFA with null transitions;
        (nulltrans) is a mapping from states to a list of states
        that can be reached by a null transition.  Unlike a DFA,
        (trans) need not have an entry for every (state, symbol).
        """
        DFA.__init__(self, *args)
        self.nulltrans = dict()

    def copy(self, constructor=None):
        """make a copy of this NFA, returning an NFA
        (and ignoring constructor)"""
        new = DFA.copy(self, NFA)
        # copy the lists as well as the dict: add_nulltrans appends in
        # place, so shared lists would leak new null transitions into self
        new.nulltrans = _merge_nulltrans(self)
        return new

    def add_nulltrans(self, src, dest):
        """add a null transition from src to dest"""
        self.nulltrans.setdefault(src, []).append(dest)

    def process(self, string):
        """read a string and print 'accept' if the FA accepts the
        string and 'reject' otherwise.

        Returns True if the string was accepted, False otherwise.
        """

        # start with the null closure of the start state
        states = self.null_closure(self.start)

        for symbol in string:
            # accumulate every state reachable on this symbol,
            # followed by any number of null transitions
            new_states = set()
            for state in states:
                try:
                    dest = self.trans[state, symbol]
                except KeyError:
                    # an NFA can have states that don't have
                    # transitions for all symbols
                    continue
                new_states |= self.null_closure(dest)

            # the new set of states is the set of states we can reach
            states = new_states

        # if any of the states we can reach is an accept state, accept
        accepted = any(state in self.accept for state in states)
        print('accept' if accepted else 'reject')
        return accepted

    def null_closure(self, state):
        """compute the null closure of state, which is the set of
        states (including state) that can be reached from state
        following only null transitions.
        """
        closure = set([state])
        # work stack of states whose null transitions are unexplored;
        # visiting order does not affect the resulting set
        stack = [state]
        while stack:
            src = stack.pop()
            for dest in self.nulltrans.get(src, ()):
                if dest not in closure:
                    closure.add(dest)
                    stack.append(dest)
        return closure

    def star(self):
        """return a new NFA that represents self*; for example,
        if self accepts ab, then self* should accept
        ab, abab, ababab, etc., plus the null string.
        See page 62 of Sipser.
        """
        # the copy has its own states, accept, trans and nulltrans,
        # so the additions below leave self unchanged
        new = self.copy(NFA)
        new.start = State()
        new.states.add(new.start)
        new.accept.add(new.start)

        new.add_nulltrans(new.start, self.start)
        # loop back from each old accept state; new.start's own null
        # transition then leads on to the old start state
        for state in self.accept:
            new.add_nulltrans(state, new.start)
        return new

    def union(self, other):
        """return a new NFA that represents self U other;
        that is, it should accept any string accepted by
        self or other, and reject all others.
        See page 59 of Sipser.

        other may be a DFA or an NFA.  self and other are assumed to
        have disjoint states; for a shared (state, symbol) key, other's
        transition wins.
        """
        # form the union of the states
        states = self.states.union(other.states)

        # compute the union of the alphabets
        alphabet = self.alphabet.union(other.alphabet)

        # form the union of the transition functions
        trans = self.trans.copy()
        trans.update(other.trans)

        # create a new start state
        start = State()
        states.add(start)

        # form the union of the accept states
        accept = self.accept.union(other.accept)

        # create the NFA, keeping the null transitions of both operands
        new = NFA(states, alphabet, trans, start, accept)
        new.nulltrans = _merge_nulltrans(self, other)

        # add null transitions from the new start to the
        # start states of self and other
        new.add_nulltrans(new.start, self.start)
        new.add_nulltrans(new.start, other.start)

        return new

    def concat(self, other):
        """return a new NFA that represents self o other;
        that is, it should accept any string accepted by
        self followed by a string accepted by other.
        See page 61 of Sipser.

        other may be a DFA or an NFA.  self and other are assumed to
        have disjoint states; for a shared (state, symbol) key, other's
        transition wins.
        """

        # form the union of the states
        states = self.states.union(other.states)

        # compute the union of the alphabets
        alphabet = self.alphabet.union(other.alphabet)

        # form the union of the transition functions
        trans = self.trans.copy()
        trans.update(other.trans)

        # the new start state is self.start
        start = self.start

        # the accept states are other.accept
        accept = other.accept.copy()

        # create the NFA, keeping the null transitions of both operands
        new = NFA(states, alphabet, trans, start, accept)
        new.nulltrans = _merge_nulltrans(self, other)

        # add null transitions from the accept states of
        # self to the start state of other
        for state in self.accept:
            new.add_nulltrans(state, other.start)

        return new


def make_dfa(alphabet, regexp):
    """make a DFA that recognizes (regexp), for the minimal
    regular expression language, which includes strings of
    symbols, and no special characters.

    alphabet -- set of input symbols
    regexp   -- sequence of symbols, each of which must be in alphabet

    Raises ValueError if regexp contains a symbol that is not in
    alphabet; without this check the result would silently accept
    nothing, because append never creates a transition into its new
    accept state for such a symbol.
    """
    for char in regexp:
        if char not in alphabet:
            raise ValueError('symbol %r is not in the alphabet' % (char,))

    # start with a DFA that accepts only the null string
    start = State()
    pit = State()
    states = set([start, pit])
    accept = set([start])

    # all transitions lead to the pit!
    trans = dict()
    for symbol in alphabet:
        trans[start, symbol] = pit
        trans[pit, symbol] = pit

    # build the DFA
    dfa = DFA(states, alphabet, trans, start, accept)
    dfa.pit = pit

    # append the symbols from regexp one at a time
    for char in regexp:
        dfa = dfa.append(char)

    return dfa
    

def main():
    """demo: build DFAs for 'ab' and 'ba', combine them into an NFA
    (by default ab followed by ba), dump it, and report accept/reject
    for a handful of test strings."""
    alphabet = set(['a', 'b'])
    ab = make_dfa(alphabet, 'ab')
    ba = make_dfa(alphabet, 'ba')

    # alternative experiments: uncomment one to try it
    #dfa = ab.union(ba)

    nfa = ab.copy(NFA)
    #nfa = nfa.star()
    #nfa = nfa.union(ba)
    nfa = nfa.concat(ba)
    nfa.dump()

    for string in ['', 'a', 'ab', 'aba', 'abab', 'ababa', 'ababab',
                   'b', 'ba', 'bab', 'abba']:
        nfa.process(string)

# run the demo only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()
    
