Source code for fimdp.explicit

from collections import defaultdict

from .core import ProductConsMDP


def product_energy(cmdp, capacity, targets=None):
    """Explicit encoding of energy into state-space

    The state-space of the newly created MDP consists of tuples `(s, e)`,
    where `s` is the state of the input CMDP and `e` is the energy level.
    For a tuple-state `(s,e)` and an action `a` with consumption (in the
    input CMDP) `c`, all successors of the action `a` in the new MDP are of
    the form `(s', e-c)` for non-reload states and `(r, capacity)` for
    reload states. If `e-c` would be negative, the action instead leads to a
    single sink state `(-1, "-∞")`. The sink is created on first use and its
    only action is a self-loop labelled `σ` with consumption 1.

    :param cmdp: the input consumption MDP
    :param capacity: the maximal energy level; every state of `cmdp` gets a
        product state with this level, and reload successors are reset to it
    :param targets: iterable of target states of `cmdp` (default: none)
    :return: pair `(result, otargets)`. `result` is the `ProductConsMDP`.
        `otargets` lists the numbers of the product states `(t, e)` with
        `t` in `targets` and `e >= 0`, in creation order.
    """
    # `None` instead of a shared mutable default; a set makes the
    # per-state membership test O(1).
    target_set = set() if targets is None else set(targets)

    result = ProductConsMDP(cmdp, capacity)

    # Output states whose successors are not yet computed. Items are
    # triplets `(s, e, p)`: `s` is the state number in the input mdp, `e`
    # is the energy level and `p` is the state number in the output mdp.
    todo = []
    otargets = []
    sink = None  # created lazily, only if some action can run out of energy

    def dst(s, e):
        """Return the output state for `(s, e)`, creating and queueing it
        if it does not exist yet."""
        p = result.get_state(s, e)
        if p is None:
            p = result.new_state(s, e, reload=cmdp.is_reload(s))
            if s in target_set and e >= 0:
                otargets.append(p)
            todo.append((s, e, p))
        return p

    # Initialization: for each state of the input mdp add a state with
    # full energy.
    for s in range(cmdp.num_states):
        dst(s, capacity)

    # Build all states and edges of the product.
    while todo:
        s, e, p = todo.pop()
        for a in cmdp.actions_for_state(s):
            # Not enough energy for this action: it leads to the sink.
            if e - a.cons < 0:
                if sink is None:
                    sink = result.new_state(-1, "-∞", name="sink,-∞")
                    result.add_action(sink, {sink: 1}, "σ", 1, None)
                result.add_action(p, {sink: 1}, a.label, a.cons, a)
                continue

            # Reload successors get full energy; others pay the consumption.
            odist = {}
            for succ, prob in a.distr.items():
                new_e = capacity if cmdp.is_reload(succ) else e - a.cons
                odist[dst(succ, new_e)] = prob
            result.add_action(p, odist, a.label, a.cons, a)

    return result, otargets
# Decompose an MDP into its maximal end components (MECs). Consumption is
# ignored.
#
# The decomposition uses the classical iterative refinement:
# 1. Compute the SCCs of the graph induced by the currently allowed actions
#    (Tarjan's algorithm via `_Graph`).
# 2. Discard every action that may leave the SCC of its state.
# 3. Discard states that have no action left.
# 4. Repeat until nothing changes.
#
# Discarding actions matters: an action that can leave a component must not
# be used to connect states inside it. The Tarjan implementation in `_Graph`
# is inspired by:
# https://www.geeksforgeeks.org/tarjan-algorithm-find-strongly-connected-components/
def get_MECs(mdp):
    """Given an MDP (not necessarily a consMDP), compute its
    maximal-end-components decomposition.

    An end component is a set of states `S` together with, for each state,
    a non-empty set of actions. Every successor of those actions must lie in
    `S`, and `S` must be strongly connected using only those actions. MECs
    are the maximal end components.

    :param mdp: object providing `num_states` and `actions_for_state(s)`;
        its actions provide `get_succs()` (the set of successor states)
    :return: list of mecs; each mec is a list of state numbers
    """
    alive = set(range(mdp.num_states))
    allowed = {s: list(mdp.actions_for_state(s)) for s in alive}

    while True:
        # Graph induced by the allowed actions on the alive states. Every
        # alive state is a vertex even without outgoing edges, and edges into
        # discarded states are dropped.
        g = _Graph()
        for s in alive:
            succs = set()
            for a in allowed[s]:
                succs.update(a.get_succs())
            g.graph[s] = sorted(succs & alive)
        sccs = g.SCC()
        component = {s: i for i, scc in enumerate(sccs) for s in scc}

        changed = False
        # Keep only actions whose successors all stay in the state's SCC.
        # Successors that were discarded have no component and fail the test.
        for s in alive:
            home = component[s]
            kept = []
            for a in allowed[s]:
                a_succs = a.get_succs()
                if a_succs and all(component.get(t) == home for t in a_succs):
                    kept.append(a)
            if len(kept) != len(allowed[s]):
                allowed[s] = kept
                changed = True

        # States without any remaining action cannot belong to a MEC.
        dead = {s for s in alive if not allowed[s]}
        if dead:
            alive -= dead
            changed = True

        if not changed:
            # Fixpoint: every SCC is closed under its allowed actions, so
            # each SCC is a MEC.
            return sccs
class _Graph: ''' Represent graphs using adjacency list representation ''' def __init__(self): # default dictionary to store graph self.graph = defaultdict(list) self.Time = 0 self.sccs = [] # function to add an edge to graph def addEdge(self, u, v): self.graph[u].append(v) def _SCCUtil(self, u, low, disc, stackMember, st): '''A recursive function that find finds strongly connected components using DFS traversal u --> The vertex to be visited next disc[] --> Stores discovery times of visited vertices low[] -- >> earliest visited vertex (the vertex with minimum discovery time) that can be reached from subtree rooted with current vertex st -- >> To store all the connected ancestors (could be part of SCC) stackMember[] --> bit/index array for faster check whether a node is in stack ''' # Initialize discovery time and low value disc[u] = self.Time low[u] = self.Time self.Time += 1 stackMember[u] = True st.append(u) # Go through all vertices adjacent to this for v in self.graph[u]: # If v is not visited yet, then recur for it if disc[v] == -1: self._SCCUtil(v, low, disc, stackMember, st) # Check if the subtree rooted with v has a connection to # one of the ancestors of u # Case 1 (per above discussion on Disc and Low value) low[u] = min(low[u], low[v]) elif stackMember[v]: # Update low value of 'u' only if 'v' is still in stack # (i.e. it's a back edge, not cross edge). low[u] = min(low[u], disc[v]) # head node found, pop the stack and print an SCC w = -1 # To store stack extracted vertices if low[u] == disc[u]: scc = [] while w != u: w = st.pop() scc.append(w) stackMember[w] = False self.sccs.append(scc) def SCC(self): """Find strongly connected components using Tarjan's algorithm. Take `self.removed` into account: skip edges between removed nodes. Uses self._SCCUtil. 
""" # Mark all the vertices as not visited # and Initialize parent and visited, # and ap(articulation point) arrays disc = {u : -1 for u in self.graph.keys()} low = {u : -1 for u in self.graph.keys()} stackMember = {False : -1 for u in self.graph.keys()} st = [] self.sccs = [] # Call the recursive helper function # to find articulation points # in DFS tree rooted with vertex 'i' for u in self.graph: if disc[u] == -1: self._SCCUtil(u, low, disc, stackMember, st) return self.sccs def remove_vertices(self, to_remove): for u in to_remove: self.graph.pop(u, None) for u, succs in self.graph.items(): self.graph[u] = [v for v in succs if v not in to_remove] def check_bscc(self, scc): """Check if scc is botton and non-trivial -- has cycle""" for s in scc: for t in self.graph[s]: if t not in scc: return False return True def check_trivial(self, scc): for s in scc: for t in self.graph[s]: if t in scc: return False return True def _prob_attractor(mdp, attr): """ Compute a set of states from which I cannot avoid states from SCCs pointed by `scc_ind`. scc_indices are pointers to graph_sccs. We exploit that graph_sccs are in reverse topological order. """ repeat = True while repeat: repeat = False for s in range(mdp.num_states): if s in attr: continue safe = False for a in mdp.actions_for_state(s): if len(a.get_succs().intersection(attr)) == 0: safe = True if not safe: attr.add(s) repeat = True return attr def _mdp2graph(mdp): """Convert mdp to graph representation for the Tarjan algo""" ns = mdp.num_states # store the number of states g = _Graph() for s in range(ns): for t in mdp.state_succs(s): g.addEdge(s,t) return g