#!/usr/bin/python

import sys
import string
import math

# Status codes returned by Agent.think(). Python has no real constants, so
# these are ordinary module-level names that are simply never reassigned.
SOLUTION_PENDING, SOLUTION_FOUND, NO_SOLUTION_EXISTS = range(3)

def get_user_integer (msg, read_line=None, default=1):
	"""Prompt for a positive (non-zero) integer.

	msg       -- prompt text passed to the line reader.
	read_line -- optional callable taking the prompt and returning the typed
	             text; when omitted the builtin raw_input is used.
	default   -- value returned when the input is not a positive integer.

	Returns the parsed integer, or `default` for empty, non-numeric, zero or
	negative input."""
	if read_line is None:
		read_line = raw_input
	userstring = read_line(msg)
	try:
		number = int(userstring)
	except (TypeError, ValueError):
		# int() raises ValueError for non-numeric text (including ""), and
		# TypeError if a custom reader returns a non-string such as None.
		return default
	if number <= 0:
		return default
	return number

class Position (object):
	"""Representation of a 2-dimensional grid location.

	Supports component-wise '+', '-' and '*', value equality/inequality and
	hashing, so instances can be used as dictionary keys and set members."""

	def __init__ (self, x = 0, y = 0):
		self.x = x
		self.y = y

	def __eq__ (self, cmp_to):
		"""Equality testing '=='; equal when both coordinates match.

		Comparing with a non-Position returns NotImplemented, so for example
		`pos == None` evaluates to False instead of raising AttributeError."""
		if not isinstance(cmp_to, Position):
			return NotImplemented
		return self.x == cmp_to.x and self.y == cmp_to.y

	def __ne__ (self, cmp_to):
		"""Inequality testing '!='.

		Python 2 does not derive '!=' from __eq__. Without this method, '!='
		would compare object identity instead of coordinates."""
		result = self.__eq__(cmp_to)
		if result is NotImplemented:
			return result
		return not result

	def __hash__ (self):
		"""Hash on the coordinate pair, consistent with __eq__."""
		return hash((self.x, self.y))

	def __str__ (self):
		"""'Stringification' method."""
		return "Position: (x, y) = [%d, %d]" % (self.x, self.y)

	def __repr__ (self):
		"""Unambiguous representation, useful when debugging containers."""
		return "Position(%r, %r)" % (self.x, self.y)

	def __add__ (self, rhs):
		"""Method called for the '+' operator (component-wise)."""
		return Position(self.x + rhs.x, self.y + rhs.y)

	def __sub__ (self, rhs):
		"""Method called for the '-' operator (component-wise)."""
		return Position(self.x - rhs.x, self.y - rhs.y)

	def __mul__ (self, rhs):
		"""Method called for the '*' operator (component-wise, not a dot product)."""
		return Position(self.x * rhs.x, self.y * rhs.y)

class Graphnode (object):
	"""A node in the agent's search tree.

	Nodes ORDER by their total cost f (so an open list can be sorted
	cheapest-first), but they compare EQUAL when they share a map location.
	The root node is its own parent."""

	def __init__ (self):
		self.parent = self      # a node whose parent is itself is the root
		self.f = 0              # total cost estimate, g + h
		self.g = 0              # cost of the path from the root to this node
		self.h = 0              # heuristic estimate from this node to the goal
		self.location = None    # map location of this node

	def __str__ (self):
		return "Graphnode: (f, g, h) = [%f, %f, %f]\n\t%s" % (self.f, self.g, self.h, str(self.location))

	def __lt__ (self, cmp_to):
		"""Order nodes by f(); used by list.sort() and sorted()."""
		return self.f < cmp_to.f

	def __cmp__ (self, cmp_to):
		"""Python 2 three-way comparison on f(): -1, 0 or 1.

		Kept for callers that use cmp(); on Python 3 this method is ignored
		and __lt__ provides ordering."""
		return (self.f > cmp_to.f) - (self.f < cmp_to.f)

	def __eq__ (self, cmp_to):
		"""Equality testing; True when both nodes share the same location.

		Returns NotImplemented for non-Graphnode operands instead of raising."""
		if not isinstance(cmp_to, Graphnode):
			return NotImplemented
		return self.location == cmp_to.location

	def __ne__ (self, cmp_to):
		"""Inverse of __eq__.

		Built on '==' of the locations rather than '!=', so it does not depend
		on the location type defining __ne__ (Python 2 does not derive it)."""
		result = self.__eq__(cmp_to)
		if result is NotImplemented:
			return result
		return not result

	def __hash__ (self):
		"""Hash on the location, consistent with __eq__."""
		return hash(self.location)
		
class Worldmap (object):
	"""Represents a rectangular world of tiles of varying sort.

	The map is loaded from a textual file by the load_map()-method. Internally
	it is a list of rows indexed map_data[y][x]; each row is a list of
	one-character symbols: 'X' (blocked), ' ' (free) or a digit '0'-'9'."""

	# Symbols accepted verbatim as numeric tiles.
	_DIGITS = "0123456789"

	# (dx, dy) offsets of the eight neighbouring squares, in the order that
	# get_adjacent() probes them.
	_NEIGHBOUR_OFFSETS = ((0, 1), (0, -1), (1, 0), (-1, 0),
		(1, 1), (1, -1), (-1, 1), (-1, -1))

	def __init__ (self):
		self.filename = ""
		self.width = 0
		self.height = 0
		self.case_sensitive = False  # if False, input symbols are upper-cased
		self.map_data = []           # rows of symbols, indexed [y][x]
		self.map_type = "binary"

	def load_map (self, filename):
		"""Load the map stored in `filename`.

		The file starts with two 'key = value' header lines giving 'height'
		and 'width' (in either order), followed by `height` rows of exactly
		`width` symbols. Trailing '\\n' and '\\r' characters are ignored, so
		both UNIX and DOS line endings work.

		Returns a dictionary with the map's 'start' ('S') and 'goal' ('E')
		Positions; a key is absent if the corresponding symbol is missing.
		Raises IOError if the file cannot be opened, and ValueError for a
		malformed header, a row of the wrong length, too few rows, or an
		unknown symbol."""

		self.filename = filename
		self.map_data = []
		stuff = { }
		num_ignored = 0 # Number of stripped newline characters.

		# FIXME: only "binary"-style maps are supported for now.
		if self.map_type is None:
			self.map_type = "binary"

		# 'with' guarantees the file is closed even if parsing raises.
		with open(self.filename, "r") as inputfile:
			self._read_header(inputfile)
			if self.height <= 0:
				raise ValueError("Map header does not declare a positive height.")

			for y in range(self.height):
				line = inputfile.readline()
				if line == "":
					# readline() returns "" only at end of file.
					raise ValueError("Map height does not match matrix contents.")
				row, ignored = self._parse_row(line, y, stuff)
				num_ignored += ignored
				if len(row) != self.width:
					raise ValueError("Map width does not match matrix contents.")
				self.map_data.append(row)

		# Whine a little, the users deserve it you know...
		print("Warning: %d symbols in the input stream were ignored (newlines?)." % num_ignored)

		# Return the locations of the 'start' and 'goal' nodes.
		return stuff

	def _read_header (self, inputfile):
		"""Read the two 'key = value' header lines; sets height and width."""
		for _ in range(2):
			parts = [part.strip() for part in inputfile.readline().split('=')]
			if parts[0] == "height":
				self.height = int(parts[1])
			elif parts[0] == "width":
				self.width = int(parts[1])

	def _parse_row (self, line, y, stuff):
		"""Translate one raw text line into a row of internal map symbols.

		'S' and 'E' squares are recorded in `stuff` (as 'start' / 'goal') and
		stored as free squares. Returns (row, number_of_ignored_characters).
		Raises ValueError on any symbol outside the accepted set."""
		row = []
		ignored = 0
		for mapsymbol in line:
			# The internal map is always upper-case.
			if not self.case_sensitive:
				mapsymbol = mapsymbol.upper()

			if mapsymbol == '\n' or mapsymbol == '\r':
				ignored += 1
			elif mapsymbol == "X" or mapsymbol == " " or mapsymbol in self._DIGITS:
				row.append(mapsymbol)
			elif mapsymbol == "S" or mapsymbol == "E":
				if self.map_type != "binary":
					raise ValueError("Unable to handle %s symbols in non-binary maps. No valid substitution." % mapsymbol)
				# Column index in the internal row, not the raw text line.
				here = Position(len(row), y)
				if mapsymbol == "S":
					stuff["start"] = here
				else:
					stuff["goal"] = here
				row.append(" ")
			else:
				raise ValueError("Invalid symbol found in map file.")
		return row, ignored

	def valid_map_location (self, pos):
		"""Validates a given map location.
		The position is valid if within map bounds and doesn't contain an 'X'."""
		if pos.x >= self.width or pos.x < 0:
			return False
		if pos.y >= self.height or pos.y < 0:
			return False
		if self.map_data[pos.y][pos.x] == "X":
			return False

		return True

	def get_data (self, pos):
		"""Convenience map data access method.
		Less error-prone than manual access, since the matrix is (y, x)."""
		return self.map_data[pos.y][pos.x]

	def get_adjacent (self, pos):
		"""Return all accessible locations among the 8 neighbours of `pos`.
		Return value is a dictionary with Position instances as keys, and
		the map symbols as values. Only valid locations are returned."""
		adjacent = { }
		for dx, dy in self._NEIGHBOUR_OFFSETS:
			test = Position(pos.x + dx, pos.y + dy)
			if self.valid_map_location(test):
				adjacent[test] = self.get_data(test)
		return adjacent

class Scenario (object):
	"""A problem that an agent must solve.
	It has a worldmap, a start and a goal location."""

	def __init__ (self, filename):
		# Construction and reloading share one code path.
		self.setup(filename)

	def setup (self, filename):
		"""(Re)load the scenario from the map file `filename`.

		Sets self.world, self.start and self.goal. Propagates IOError and
		ValueError from Worldmap.load_map(), and raises KeyError if the map
		has no start ('S') or goal ('E') square."""
		self.world = Worldmap()
		locations = self.world.load_map(filename)
		self.start = locations["start"]
		self.goal = locations["goal"]

class Agent (object):
	"""A path-finding A*-based agent.

	Call setup() with a Scenario, then call think() repeatedly until it
	returns something other than SOLUTION_PENDING. The open list holds
	Graphnode instances that were discovered but not yet expanded; the closed
	list holds expanded ones. Each map location has at most one node across
	both lists."""

	def __init__ (self):
		self.open = []
		self.closed = []
		self.task = None
		self.location = None
		self.iterations = 0
		self.path = []
		# When True, think() prints its per-iteration trace (open/closed
		# lists and successor bookkeeping).
		self.verbose = True
		self.cost_function = self.cost_simple
		self.heuristic_function = self.heuristic_manhattan

	def setup (self, task):
		"""Initializes the agent to work with the supplied task.

		Resets all search state, seeds the open list with a root node at
		task.start, and restores the default cost and heuristic functions."""
		self.open = []
		self.closed = []
		self.task = task
		self.location = task.start
		initial_node = Graphnode()
		initial_node.location = self.location
		initial_node.parent = initial_node  # the root is its own parent
		self.open.append(initial_node)
		self.iterations = 0
		self.path = []
		self.cost_function = self.cost_simple
		self.heuristic_function = self.heuristic_manhattan

	def heuristic_manhattan (self, loc):
		"""Return |dx| + |dy| from `loc` to the task's goal, as a float.

		NOTE: diagonal moves are allowed and cost_simple() charges 1 per step,
		so this can over-estimate the remaining cost."""
		goal = self.task.goal
		return math.fabs(goal.x - loc.x) + math.fabs(goal.y - loc.y)

	def heuristic_straight_line (self, loc):
		"""Return the Euclidean distance from `loc` to the task's goal."""
		goal = self.task.goal
		return math.hypot(goal.x - loc.x, goal.y - loc.y)

	def cost_simple (self, node):
		"""Return the number of steps on the current path from the root to `node`.

		Follows parent links until reaching the root (the node that is its own
		parent), so the result is always the real length of the chain."""
		cost = 0
		while node.parent is not node:
			node = node.parent
			cost += 1
		return cost

	def cost_numeric (self, node):
		"""Returns the cost of the current path, for a numeric map.
		TODO: not implemented yet; currently returns None."""
		return None

	def direction_penalty (self, cur_node, test_node):
		"""Return a directional term for using 'test_node' as a successor.

		Computes the dot product of vector A = cur - parent and vector B =
		cur - test. Because B points from the test node back to the current
		node, continuing straight ahead gives a negative value.
		NOTE(review): confirm the intended sign before enabling this.
		Returns 0 for the root node; raises ValueError if both nodes share a
		location."""
		if test_node == cur_node:
			raise ValueError("Testing a node against itself.")

		if cur_node.parent == cur_node:
			return 0

		vec_a = cur_node.location - cur_node.parent.location
		vec_b = cur_node.location - test_node.location

		return vec_a.x * vec_b.x + vec_a.y * vec_b.y

	def get_path (self, node):
		"""Return the list of nodes from the root node to `node`, inclusive."""
		path = [node]
		while node.parent is not node:
			node = node.parent
			path.append(node)
		path.reverse()
		return path

	def _log (self, msg):
		"""Print `msg` when verbose tracing is enabled."""
		if self.verbose:
			print(msg)

	def _dump (self, title, nodes):
		"""Print `title` followed by each node, when verbose tracing is enabled."""
		if self.verbose:
			print(title)
			for node in nodes:
				print(node)

	def _reparent (self, existing, candidate):
		"""Give `existing` the cheaper parent and cost found via `candidate`."""
		existing.parent = candidate.parent
		existing.g = candidate.g
		existing.f = existing.g + existing.h

	def _consider (self, node):
		"""Merge a freshly generated successor into the open/closed lists.

		If the location is already open, keep that node but re-parent it when
		the new path is cheaper. If it is already closed, re-open it only when
		the new path is cheaper. Otherwise add `node` to the open list.
		Existing costs are recomputed with cost_function rather than read from
		the (possibly stale) stored g. This prevents an ancestor from being
		re-parented onto its own descendant, which would create a cycle."""
		for existing in self.open:
			if existing.location == node.location:
				if node.g < self.cost_function(existing):
					self._log("cheaper path to open node: " + str(existing.location))
					self._reparent(existing, node)
				return

		for index, existing in enumerate(self.closed):
			if existing.location == node.location:
				if node.g < self.cost_function(existing):
					self._log("re-opening closed node: " + str(existing.location))
					self._reparent(existing, node)
					del self.closed[index]
					self.open.append(existing)
				return

		self._log("added node " + str(node.location))
		self.open.append(node)

	def think (self):
		"""The main solving iteration method.
		Call this repeatedly in order to solve the current task.
		Returns SOLUTION_PENDING, SOLUTION_FOUND (self.path then holds the
		root-to-goal node list) or NO_SOLUTION_EXISTS."""

		# If the open list is exhausted, we have failed utterly...
		if not self.open:
			return NO_SOLUTION_EXISTS

		# Stable sort on f(): among equal-cost nodes the earlier one wins.
		self.open.sort(key=lambda n: n.f)
		self._dump("The open list:", self.open)
		cur_node = self.open.pop(0)
		self._dump("The closed list:", self.closed)

		# If the node is the goal node, return success after storing path.
		self.location = cur_node.location
		self.closed.append(cur_node)
		if self.location == self.task.goal:
			self.path = self.get_path(cur_node)
			return SOLUTION_FOUND

		adjacent = self.task.world.get_adjacent(cur_node.location)
		self._log("Got %d accessible locations from the map." % len(adjacent))

		# Build every successor first, then merge them into the lists.
		successors = []
		for adj_loc in adjacent:
			new_node = Graphnode()
			new_node.parent = cur_node
			new_node.location = adj_loc
			new_node.g = self.cost_function(new_node)
			new_node.h = self.heuristic_function(adj_loc)
			new_node.f = new_node.g + new_node.h
			successors.append(new_node)

		for node in successors:
			self._consider(node)

		self.iterations += 1
		self._log("Iteration #%d" % self.iterations)

		return SOLUTION_PENDING

	def show (self):
		"""Prints the current state of the agent's task solving progress.

		Precedence per square: goal 'E', start 'S', agent 'A', final path
		'o', closed '-', open '+', otherwise the map symbol."""
		# Location sets give O(1) membership tests per square.
		path_locs = set(node.location for node in self.path)
		closed_locs = set(node.location for node in self.closed)
		open_locs = set(node.location for node in self.open)

		for y, row in enumerate(self.task.world.map_data):
			chars = []
			for x, symbol in enumerate(row):
				cur_pos = Position(x, y)
				if cur_pos == self.task.goal:
					chars.append("E")
				elif cur_pos == self.task.start:
					chars.append("S")
				elif cur_pos == self.location:
					chars.append("A")
				elif cur_pos in path_locs:
					chars.append("o")
				elif cur_pos in closed_locs:
					chars.append("-")
				elif cur_pos in open_locs:
					chars.append("+")
				else:
					chars.append(symbol)
			print("".join(chars))

		# Print explanatory string to please the user.
		print("+ (open)  - (closed)  o (final path)")


### Main script start ###

def main ():
	"""Command-line driver.

	Loads the map file named by sys.argv[1], asks the user for a run mode
	and heuristic, then runs the agent until it finishes. Exits with status
	1 if the map argument is missing, the file cannot be opened, or the map
	is malformed."""
	if len(sys.argv) < 2:
		print("Error: Please supply a map filename as argument to this script.")
		sys.exit(1)
	filename = sys.argv[1]
	print("Using map file: \"%s\"" % filename)

	try:
		myscenario = Scenario(filename)
	except IOError:
		print("Error: Unable to open map: \"%s\"" % filename)
		sys.exit(1)
	except (ValueError, KeyError) as err:
		# ValueError: malformed map; KeyError: missing 'S' or 'E' square.
		print("Error: Invalid map \"%s\": %s" % (filename, err))
		sys.exit(1)

	myagent = Agent()
	myagent.setup(myscenario)

	print("\nSelect simulation speed/visualization:\n")
	print("1. Interactive: Redraw and wait after each succession.")
	print("2. Automatic: Full speed, no drawing until solution has been found.")

	mode = get_user_integer("\nWhat mode do you wish to use? (default: Interactive): ")
	# Only the interactive mode wants the per-iteration debug trace.
	myagent.verbose = (mode == 1)

	print("\nSelect the agent's heuristic function:\n")
	print("1. Manhattan distance")
	print("2. Straight line distance")

	heur = get_user_integer("\nWhich heuristic do you want to use? (default: Manhattan): ")
	# Must be applied after setup(), which resets the heuristic to Manhattan.
	if heur == 2:
		myagent.heuristic_function = myagent.heuristic_straight_line

	myagent.show()
	raw_input("\nPress [enter] to begin simulation...")

	# Loop until we have solved the stuff, or it b0rks completely.
	ret = SOLUTION_PENDING
	usercounter = 0
	while ret == SOLUTION_PENDING:
		ret = myagent.think()
		# Don't prompt again once the search has finished.
		if mode == 1 and usercounter == 0 and ret == SOLUTION_PENDING:
			myagent.show()
			usercounter = get_user_integer("\nNumber of iterations to run non-interactively (default = 1)... ")
		usercounter = usercounter - 1

	myagent.show()
	if ret == NO_SOLUTION_EXISTS:
		print("A solution could not be obtained for this problem.")
	elif ret == SOLUTION_FOUND:
		print("A solution was found after %d iterations." % myagent.iterations)

if __name__ == "__main__":
	main()



