# (c) 2009 Peter Goodman, all rights reserved
# Solutions for UWO CS3346a Assignment 1 Question 4

# The bank of the river on which the boat currently sits.
LEFT = 0
RIGHT = 1

# A state is the tuple (m1, c1, m2, c2, side): the missionaries and cannibals
# on the left bank, the missionaries and cannibals on the right bank, and the
# bank the boat is on.
INITIAL_STATE = (3, 3, 0, 0, LEFT)
GOAL_STATE = (0, 0, 3, 3, RIGHT)

def state_is_legal(m1, c1, m2, c2, side):
    """Tell whether a configuration obeys the rules of the puzzle.

    m1 and c1 count the missionaries and cannibals on the left bank, m2 and
    c2 those on the right bank, and side is the bank the boat is on. Being
    legal does not imply that the state can be reached from the start.
    """
    # there are exactly three missionaries and three cannibals overall
    head_counts_ok = (m1 + m2) == 3 and (c1 + c2) == 3
    if not head_counts_ok:
        return False

    # a bank cannot hold a negative number of people
    if min(m1, m2, c1, c2) < 0:
        return False

    # wherever missionaries are present they must not be outnumbered
    if 0 < m1 < c1 or 0 < m2 < c2:
        return False

    # the boat cannot be on a bank while all six people are on the other one
    if m1 + c1 == 6 and side == RIGHT:
        return False
    return not (m2 + c2 == 6 and side == LEFT)

def gen_successor_states(m1, c1, m2, c2, side):
    """Generate the five candidate successor states of a given state.

    m1 and c1 count the missionaries and cannibals on the left bank, m2 and
    c2 those on the right bank, and side is the bank the boat is on. The boat
    carries its passengers away from the bank it is on and ends up on the
    opposite bank. The yielded (m1, c1, m2, c2, side) tuples are not checked
    for legality; counts may even be negative.
    """
    # (missionaries, cannibals) in the boat for each possible crossing
    boat_loads = (
        (1, 1),  # 1 of both over
        (1, 0),  # 1 missionary over
        (0, 1),  # 1 cannibal over
        (0, 2),  # 2 cannibals over
        (2, 0),  # 2 missionaries over
    )

    # compare by value: an identity test on an integer only works by accident
    # of the interpreter caching small ints
    if side == LEFT:
        for m, c in boat_loads:
            yield m1 - m, c1 - c, m2 + m, c2 + c, RIGHT
    else:
        for m, c in boat_loads:
            yield m1 + m, c1 + c, m2 - m, c2 - c, LEFT

def successors(state):
    """Collect, as a set, every legal state one boat crossing away from
    the given state."""
    return set(
        candidate
        for candidate in gen_successor_states(*state)
        if state_is_legal(*candidate)
    )

def state_label(state):
    """Build the DOT record label for a state.

    The label has three fields: one "M" per missionary and one "C" per
    cannibal on the left bank, an arrow entity pointing at the bank the boat
    is on, and the same M/C tally for the right bank.
    """
    if state[4] == LEFT:
        boat_arrow = "&larr;"
    else:
        boat_arrow = "&rarr;"

    left_missionaries, left_cannibals = "M" * state[0], "C" * state[1]
    right_missionaries, right_cannibals = "M" * state[2], "C" * state[3]

    fields = (
        left_missionaries,
        left_cannibals,
        boat_arrow,
        right_missionaries,
        right_cannibals,
    )
    return "<f0>%s%s|<f1>%s|<f2>%s%s" % fields

def state_space():
    """Build the state space as a dict mapping each state to the set of its
    legal successor states.

    The dict holds every state that can be reached by following successors
    from either the initial state or the goal state.
    """
    graph = {}
    visited = set()
    frontier = set([INITIAL_STATE, GOAL_STATE])

    # expand one level at a time until a level discovers nothing new
    while frontier:
        discovered = set()
        for current in frontier:
            graph[current] = successors(current)
            discovered.update(graph[current])

        # everything in this level is now expanded; only states that have
        # never been expanded make up the next level
        visited.update(frontier)
        frontier = discovered - visited

    return graph

def print_state_space():
    """Print the state space as a directed graph in the DOT language.

    After a two-line heading, every state is written as a numbered
    record-shaped node followed by one edge line per successor. The goal
    state is coloured red and the initial state green.
    """
    space = state_space()

    # give every state a small integer to use as its DOT node name
    keys = dict((state, index) for index, state in enumerate(space))

    # only the two distinguished states carry extra styling
    styles = dict.fromkeys(space, "")
    styles[GOAL_STATE], styles[INITIAL_STATE] = " color=red", " color=green"

    # Every print below is given exactly one pre-formatted string, so the
    # line is valid and prints the same text whether print is a statement
    # (Python 2) or a function (Python 3). The space before "\n" reproduces
    # the separator the multi-item print statement used to emit.
    print("State Space: \n------------")

    for state in space:
        print("%s [label=\"%s\" shape=record %s]" % (
            keys[state],
            state_label(state),
            styles[state],
        ))
        for succ in space[state]:
            print("%s -> %s" % (keys[state], keys[succ]))

##############################################################################
       
class Node(object):
    """A node of the search tree: a puzzle state together with the depth at
    which it was generated.

    Each node takes its id from the class-wide counter node_id, which is
    advanced by one for every node created.
    """
    node_id = 0

    def __init__(self, state, depth):
        self.state = state
        self.depth = depth

        # claim the next id and move the shared counter past it
        self.id = Node.node_id
        Node.node_id = self.id + 1

    def __repr__(self):
        """Debugging representation: the (state, depth, cost) triple."""
        summary = (self.state, self.depth, self.cost())
        return repr(summary)

    def cost(self):
        """Return the depth of this node plus an estimate of the number of
        boat trips still needed to move everyone off the left bank."""
        people_left = self.state[0] + self.state[1]

        # nobody left to ferry: only the trips already made count
        if people_left <= 0:
            return self.depth

        # at least one more trip is needed to carry people over
        estimate = self.depth + 1

        # with the boat on the right bank, count a ride back with one person,
        # which puts the boat and one more person on the left bank
        if self.state[4] == RIGHT:
            estimate += 1
            people_left += 1

        # each person beyond the two of the final crossing is charged two
        # trips
        if people_left > 2:
            estimate += 2 * (people_left - 2)
        return estimate

def search_steps():
    """Run the search from INITIAL_STATE, yielding once per expanded node.

    The fringe is kept ordered by Node.cost(), so the cheapest node is
    expanded first; nodes of equal cost keep the order in which they were
    added. A node whose state has already been expanded is skipped.

    This is a generator. Each item is a 4-tuple:

    * the Node that was just taken off the fringe and expanded,
    * a list of (parent id, child id) pairs for the successors it produced
      (empty for the goal node),
    * a copy of the fringe, a list of (cost, Node) pairs in ascending cost
      order,
    * a list of the Nodes skipped since the previous item because their
      state had already been expanded.

    Iteration ends after the node holding GOAL_STATE has been yielded, or
    when the fringe runs empty.
    """
    initial = Node(INITIAL_STATE, 0)
    fringe = [(initial.cost(), initial)]
    seen_states, ignored_nodes = set(), []

    while fringe:
        _cost, node = fringe.pop(0)
        if node.state in seen_states:
            ignored_nodes.append(node)
            continue

        seen_states.add(node.state)

        if node.state == GOAL_STATE:
            yield node, [], fringe[:], ignored_nodes
            # finish by returning: raising StopIteration inside a generator
            # is turned into a RuntimeError on modern Python (PEP 479)
            return

        tree_nodes = []
        for state in successors(node.state):
            succ = Node(state, node.depth + 1)
            fringe.append((succ.cost(), succ))
            tree_nodes.append((node.id, succ.id))

        # order by cost only; the sort is stable, so ties keep their
        # insertion order and Node objects are never compared
        fringe.sort(key=lambda entry: entry[0])
        yield node, tree_nodes, fringe[:], ignored_nodes

        # start a fresh list instead of emptying the one just handed out, so
        # a consumer that keeps the yielded list does not see it change
        ignored_nodes = []

def print_search_steps():
    """Print out a DOT graph representing the progress of each step as the A*
    algorithm performs a graph search over the state space of the Missionaries
    and Cannibals problem.

    For every item produced by search_steps() two sections are printed: the
    search tree built so far (edges, ignored nodes, node labels) and the
    current fringe laid out in rows of at most three nodes.

    Every print call is given exactly one pre-formatted string, so the text
    is the same whether print is a statement (Python 2) or a function
    (Python 3). Spaces that appear before "\\n" inside the headings reproduce
    the separators of the original multi-item print statements.
    """

    def print_node(node, in_fringe, current):
        # one DOT node statement: record label plus optional fill colour
        fringe_fill = ", style=filled, fillcolor=oldlace" if in_fringe else ""
        current_fill = ", style=filled, fillcolor=olivedrab" if current else ""
        print("%s [label=\"%s|<f3>%d\" shape=record%s%s]" % (
            node.id,
            state_label(node.state),
            node.cost(),
            fringe_fill,
            current_fill,
        ))

    tree_history, node_history, ignored_history = [], [], []
    for expanded_node, tree_nodes, fringe, ignored_nodes in search_steps():
        print("Tree: \n-----")

        # print the tree structure accumulated over all steps so far
        tree_history.extend(tree_nodes)
        for node_id, succ_id in tree_history:
            print("%s -> %s" % (node_id, succ_id))

        # print out the nodes that have been ignored up to this step
        ignored_history.extend(ignored_nodes)
        for node in ignored_history:
            print("%s [style=filled, fillcolor=grey88]" % (node.id,))

        # print the node labels: the node being expanded, then every node
        # seen in earlier steps, then the nodes currently in the fringe
        print_node(expanded_node, False, True)

        for node in node_history:
            print_node(node, False, False)

        for _, node in fringe:
            node_history.append(node)
            print_node(node, True, False)

        node_history.append(expanded_node)

        # print out the fringe as a DOT graph. This is somewhat of a hack in
        # terms of how it forces DOT to lay things out in a single column of
        # states where each row has no more than 3 states in it
        print("\nFringe \n------")
        for _, node in fringe:
            print_node(node, False, False)

        for i in range(0, len(fringe), 3):
            row_ids = [str(n.id) for _, n in fringe[i:i+3]]
            print("{ rank=same %s }" % " ".join(row_ids))

            # chain the row with invisible edges when it has a neighbour
            if i < len(fringe) - 1:
                print("%s [color=white]" % " -> ".join(row_ids))

            # hang this row under the first node of the previous row
            if i > 0:
                print("%s -> %s [color=white]" % (
                    fringe[i-3][1].id,
                    fringe[i][1].id,
                ))

        print("\n--------------------------- \n")

##############################################################################

if __name__ == "__main__":
    print_state_space()
    # Blank line between the two reports. print("") writes a single newline
    # both as a Python 2 print statement and as a Python 3 function call,
    # whereas a bare `print` is an expression with no effect on Python 3.
    print("")
    print_search_steps()
