Source code for leuvenmapmatching.matcher.base

# encoding: utf-8
"""
leuvenmapmatching.matcher.base
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Base Matcher and Matching classes.

This a generic base class to be used by matchers. This class itself
does not implement a working matcher.

:author: Wannes Meert
:copyright: Copyright 2015-2021 DTAI, KU Leuven and Sirris.
:license: Apache License, Version 2.0, see LICENSE for details.
"""
from __future__ import print_function

import math
import sys
import logging
import time
from collections import OrderedDict, defaultdict, namedtuple
from itertools import islice
from typing import List, Tuple, Dict, Any, Optional, Set

import numpy as np

from ..util.segment import Segment
from ..util import approx_equal, approx_leq


logger = logging.getLogger("be.kuleuven.cs.dtai.mapmatching")
approx_value = 0.0000000001
ema_const = namedtuple('EMAConst', ['prev', 'cur'])(0.7, 0.3)
default_label_width = 25


[docs] class BaseMatching(object): """Matching object that represents a node in the Viterbi lattice.""" __slots__ = ['matcher', 'edge_m', 'edge_o', 'logprob', 'logprobema', 'logprobe', 'logprobne', 'obs', 'obs_ne', 'dist_obs', 'prev', 'prev_other', 'stop', 'length', 'delayed'] def __init__(self, matcher: 'BaseMatcher', edge_m: Segment, edge_o: Segment, logprob=-np.inf, logprobema=-np.inf, logprobe=-np.inf, logprobne=-np.inf, dist_obs: float = 0.0, obs: int = 0, obs_ne: int = 0, prev: Optional[Set['BaseMatching']] = None, stop: bool = False, length: int = 1, delayed: int = 0, **_kwargs): """ :param matcher: Reference to the Matcher used to generate this matching object. :param edge_m: Segment in the given graph (thus line between two nodes in the graph). :param edge_o: Segment in the given observations (thus line in between two observations). :param logprob: Log probability of this matching. :param logprobema: Exponential Mean Average of Log probability. :param logprobe: Emitting :param logprobne: Non-emitting :param dist_obs: Distance between map point and observation :param obs: Reference to path entry index (observation) :param obs_ne: Number of non-emitting states for this observation :param prev: Previous best matching objects :param stop: Stop after this matching (e.g. because probability is too low) :param length: Lenght of current matching sequence through lattice. :param delayed: This matching is temporarily stopped if >0 (e.g. to first explore better options). :param dist_m: Distance over graph :param dist_o: Distance over observations :param _kwargs: """ self.edge_m: Segment = edge_m self.edge_o: Segment = edge_o self.logprob: float = logprob # max probability self.logprobe: float = logprobe # Emitting self.logprobne: float = logprobne # Non-emitting self.logprobema: float = logprobema # exponential moving average log probability # TODO: Not used anymore? self.obs: int = obs # reference to path entry index (observation) self.obs_ne: int = obs_ne # number of non-emitting states for this observation self.dist_obs: float = dist_obs # Distance between map point and observation self.prev: Set[BaseMatching] = set() if prev is None else prev # Previous best matching objects self.prev_other: Set[BaseMatching] = set() # Previous matching objects with lower logprob self.stop: bool = stop self.length: int = length self.delayed: int = delayed self.matcher: BaseMatcher = matcher @property def prune_value(self): """Pruning the lattice (e.g. to delay) is based on this key.""" return self.logprob # return self.logprobema
[docs] def next(self, edge_m: Segment, edge_o: Segment, obs: int = 0, obs_ne: int = 0): """Create a next lattice Matching object with this Matching object as the previous one in the lattice.""" new_stop = False if edge_m.is_point() and edge_o.is_point(): # node to node dist = self.matcher.map.distance(edge_m.p1, edge_o.p1) # proj_m = edge_m.p1 # proj_o = edge_o.pi elif edge_m.is_point() and not edge_o.is_point(): # node to edge dist, proj_o, t_o = self.matcher.map.distance_point_to_segment(edge_m.p1, edge_o.p1, edge_o.p2) # proj_m = edge_m.p1 edge_o.pi = proj_o edge_o.ti = t_o elif not edge_m.is_point() and edge_o.is_point(): # edge to node dist, proj_m, t_m = self.matcher.map.distance_point_to_segment(edge_o.p1, edge_m.p1, edge_m.p2) if not self.matcher.only_edges and (approx_equal(t_m, 0.0) or approx_equal(t_m, 1.0)): if __debug__ and logger.isEnabledFor(logging.DEBUG): logger.debug(f" | Stopped trace: Too close to end, {t_m}") new_stop = True else: return None edge_m.pi = proj_m edge_m.ti = t_m # proj_o = edge_o.pi elif not edge_m.is_point() and not edge_o.is_point(): # edge to edge dist, proj_m, proj_o, t_m, t_o = self.matcher.map.distance_segment_to_segment(edge_m.p1, edge_m.p2, edge_o.p1, edge_o.p2) edge_m.pi = proj_m edge_m.ti = t_m edge_o.pi = proj_o edge_o.ti = t_o else: raise Exception(f"Should not happen") logprob_trans, props_trans = self.matcher.logprob_trans(self, edge_m, edge_o, is_prev_ne=(self.obs_ne != 0), is_next_ne=(obs_ne != 0)) logprob_obs, props_obs = self.matcher.logprob_obs(dist, self, edge_m, edge_o, is_ne=(obs_ne != 0)) if __debug__ and logprob_trans > 0: raise Exception(f"logprob_trans = {logprob_trans} > 0") if __debug__ and logprob_obs > 0: raise Exception(f"logprob_obs = {logprob_obs} > 0") new_logprob_delta = logprob_trans + logprob_obs if obs_ne == 0: new_logprobe = self.logprob + new_logprob_delta new_logprobne = 0 new_logprob = new_logprobe new_length = self.length + 1 else: # Non-emitting states require normalisation # "* e^(ne_length_factor_log)" or "- ne_length_factor_log" for every step to a non-emitting # state to prefer shorter paths new_logprobe = self.logprobe + self.matcher.ne_length_factor_log # The obvious choice would be average to compensate for that non-emitting states # create different path lengths between emitting nodes. # We use min() as it is a monotonic function, in contrast with an average new_logprobne = min(self.logprobne, new_logprob_delta) new_logprob = new_logprobe + new_logprobne # Alternative approach with an average # new_logprobne = self.logprobne + new_logprob_delta # "+ 1" to punish non-emitting states a bit less. Otherwise it would be # similar to (Pr_tr*Pr_obs)**2, which punishes just one non-emitting state too much. # new_logprob = new_logprobe + new_logprobne / (obs_ne + 1) new_length = self.length new_logprobema = ema_const.cur * new_logprob_delta + ema_const.prev * self.logprobema new_stop |= self.matcher.do_stop(new_logprob / new_length, dist, logprob_trans, logprob_obs) if __debug__ and new_logprob > self.logprob: raise Exception(f"Expecting a monotonic probability, " f"new_logprob = {new_logprob} > logprob = {self.logprob}") if not new_stop or (__debug__ and logger.isEnabledFor(logging.DEBUG)): m_next = self.__class__(self.matcher, edge_m, edge_o, logprob=new_logprob, logprobne=new_logprobne, logprobe=new_logprobe, logprobema=new_logprobema, obs=obs, obs_ne=obs_ne, prev={self}, dist_obs=dist, stop=new_stop, length=new_length, delayed=self.delayed, **props_trans, **props_obs) return m_next else: return None
[docs] @classmethod def first(cls, logprob_init, edge_m, edge_o, matcher, dist_obs): """Create an initial lattice Matching object.""" logprob_obs, props_obs = matcher.logprob_obs(dist_obs, None, edge_m, edge_o) logprob = logprob_init + logprob_obs new_stop = matcher.do_stop(logprob, dist_obs, logprob_init, logprob_obs) if not new_stop or logger.isEnabledFor(logging.DEBUG): m_next = cls(matcher, edge_m=edge_m, edge_o=edge_o, logprob=logprob, logprobema=logprob, logprobe=logprob, logprobne=0, dist_obs=dist_obs, obs=0, stop=new_stop, **props_obs) return m_next else: return None
[docs] def update(self, m_next): """Update the current entry if the new matching object for this state is better. :param m_next: The new matching object representing the same node in the lattice. :return: True if the current object is replaced, False otherwise """ # if self.length != m_next.length: # slogprob_norm = self.logprob / self.length # nlogprob_norm = m_next.logprob / m_next.length # else: # slogprob_norm = self.logprob # nlogprob_norm = m_next.logprob # if (self.stop == m_next.stop and slogprob_norm < nlogprob_norm) or (self.stop and not m_next.stop): # self._update_inner(m_next) # return True # elif abs(slogprob_norm - nlogprob_norm) < approx_value and self.stop == m_next.stop: # self.prev.update(m_next.prev) # self.stop = m_next.stop # return False assert self.length == m_next.length if (self.stop and not m_next.stop) \ or (self.stop == m_next.stop and self.logprob < m_next.logprob): self._update_inner(m_next) return True else: self.prev_other.update(m_next.prev) return False
def _update_inner(self, m_other: 'BaseMatching'): self.edge_m = m_other.edge_m self.edge_o = m_other.edge_o self.logprob = m_other.logprob self.logprobe = m_other.logprobe self.logprobne = m_other.logprobne self.logprobema = m_other.logprobema self.dist_obs = m_other.dist_obs self.obs = m_other.obs self.obs_ne = m_other.obs_ne self.prev_other.update(self.prev) # Do we use this? self.prev = m_other.prev self.stop = m_other.stop self.delayed = m_other.delayed self.length = m_other.length def is_nonemitting(self): return self.obs_ne != 0 def is_emitting(self): return self.obs_ne == 0 def last_emitting_logprob(self): if self.is_emitting(): return self.logprob elif self.prev is None or len(self.prev) == 0: return 0 else: return next(iter(self.prev)).last_emitting_logprob() def __str__(self, label_width=None): stop = '' if self.stop: stop = 'x' else: stop = f'{self.delayed}' if label_width is None: label_width = default_label_width repr_tmpl = "{:<2} | {:<"+str(label_width)+"} | {:10.5f} | {:10.5f} | {:10.5f} | {:10.5f} | " +\ "{:<3} | {:10.5f} | {:<" + str(label_width) + "} |" return repr_tmpl.format(stop, self.label, self.logprob, self.logprob / self.length, self.logprobema, self.logprobne, self.obs, self.dist_obs, ",".join([str(prev.label) for prev in self.prev])) def __repr__(self): return "Matching<"+str(self.label)+">" @staticmethod def repr_header(label_width=None, stop=""): if label_width is None: label_width = default_label_width repr_tmpl = "{:<2} | {:<"+str(label_width)+"} | {:<10} | {:<10} | {:<10} | {:<10} | " + \ "{:<3} | {:<10} | {:<"+str(label_width)+"} |" return repr_tmpl.format(stop, "", "lg(Pr)", "nlg(Pr)", "slg(Pr)", "lg(Pr-ne)", "obs", "d(obs)", "prev") @staticmethod def repr_static(fields, label_width=None): if label_width is None: label_width = default_label_width default_fields = ["", "", float('nan'), float('nan'), float('nan'), float('nan'), "", float('nan'), "", ""] repr_tmpl = "{:<2} | {:<" + str(label_width) + "} | {:10.5f} | {:10.5f} | {:10.5f} | {:10.5f} | " + \ "{:<3} | {:10.5f} | {:<" + str(label_width) + "} |" if len(fields) < 8: fields = list(fields) + default_fields[len(fields):] return repr_tmpl.format(*fields) @property def label(self): if self.edge_m.p2 is None: return "{}---{}-{}".format(self.edge_m.l1, self.obs, self.obs_ne) else: return "{}-{}-{}-{}".format(self.edge_m.l1, self.edge_m.l2, self.obs, self.obs_ne) @property def cname(self): if self.edge_m.l2 is None: return "{}_{}_{}".format(self.edge_m.l1, self.obs, self.obs_ne) else: return "{}_{}_{}_{}".format(self.edge_m.l1, self.edge_m.l2, self.obs, self.obs_ne) @property def key(self): """Key that indicates the node or edge, observation and non-emitting step. This is the unique key that is used in the lattice. """ if self.edge_m.l2 is None: return tuple([self.edge_m.l1, self.obs, self.obs_ne]) else: return tuple([self.edge_m.l1, self.edge_m.l2, self.obs, self.obs_ne]) @property def shortkey(self): """Key that indicates the node or edge. Irrespective of the current observation.""" if self.edge_m.l2 is None: return self.edge_m.l1 else: return tuple([self.edge_m.l1, self.edge_m.l2]) @property def nodes(self): if self.edge_m.l2 is None: return [self.edge_m.l1] else: return [self.edge_m.l1, self.edge_m.l2] def __hash__(self): return self.cname.__hash__() def __lt__(self, o): return self.logprob < o.logprob def __le__(self, o): return self.logprob <= o.logprob def __eq__(self, o): return self.logprob == o.logprob def __ne__(self, o): return self.logprob != o.logprob def __ge__(self, o): return self.logprob >= o.logprob def __gt__(self, o): return self.logprob > o.logprob
class LatticeColumn: def __init__(self, obs_idx): # 0 = obs, >0 = non-emitting between this obs and next self.obs_idx = obs_idx self.o = [] # type list[dict[label,Matching]] def __contains__(self, item): for c in self.o: if item in c: return True return False def __len__(self): return len(self.o) def set_delayed(self, delayed): """Update all delayed values.""" for c in self.o: for m in c.values(): m.delayed = delayed def dict(self, obs_ne=None): if obs_ne is None: raise AttributeError('obs_ne should be value') while obs_ne >= len(self.o): self.o.append({}) return self.o[obs_ne] def values_all(self): """All matches for the emitting layer and all non-emitting layers.""" values = set() for o in self.o: values.update(o.values()) return values def values(self, obs_ne=None): if obs_ne is None: raise AttributeError('obs_ne should be value') if len(self.o) <= obs_ne: return [] return self.o[obs_ne].values() def upsert(self, matching): # type: (BaseMatching) -> None if matching is None: return None while matching.obs_ne >= len(self.o): self.o.append({}) c = self.o[matching.obs_ne] if matching.key in c: other_matching = c[matching.key] # type: BaseMatching other_matching.update(matching) else: c[matching.key] = matching return c[matching.key] def prune(self, obs_ne, max_lattice_width, expand_upto, prune_thr=None): """Prune given column in the lattice to fit in max_lattice_width. Also ignore all matchings with a probability lower than prune_thr. These are matchings that are worse than the matchings at the next observation that are retained after pruning. :param obs_ne: :param max_lattice_width: :param expand_upto: The current expand level :return: """ cur_lattice = [m for m in self.values(obs_ne) if not m.stop] if __debug__: logger.debug('Prune lattice[{},{}] from {} to {}, with prune thr {}' .format(self.obs_idx, obs_ne, len([m for m in cur_lattice if not m.stop and m.delayed == expand_upto]), max_lattice_width, prune_thr)) cnt_pruned = 0 if max_lattice_width is not None and len(cur_lattice) > max_lattice_width: ms = sorted(cur_lattice, key=lambda t: t.prune_value, reverse=True) cur_width = max_lattice_width m_last = ms[cur_width - 1] # Extend current width if next pruned matching has same logprob as last kept matching # This increases the lattice width but otherwise the algorithm depends on the # order of edges/nodes and is not deterministic. while cur_width < len(ms) and ms[cur_width].prune_value == m_last.prune_value: m_last = ms[cur_width] cur_width += 1 if prune_thr is not None: while cur_width > 0 and ms[cur_width - 1].prune_value < prune_thr: cur_width -= 1 for m in ms[:cur_width]: # type: BaseMatching if m.delayed > expand_upto: m.delayed = expand_upto # expand now for m in ms[cur_width:]: if m.delayed <= expand_upto: if __debug__: cnt_pruned += 1 m.delayed = expand_upto + 1 # expand later if cur_width > 0: prune_thr = ms[cur_width - 1].prune_value if __debug__: logger.debug(f'Pruned {cnt_pruned} matchings, return {prune_thr}') return prune_thr
[docs] class BaseMatcher: def __init__(self, map_con, obs_noise=1, max_dist_init=None, max_dist=None, min_prob_norm=None, non_emitting_states=True, max_lattice_width=None, only_edges=True, obs_noise_ne=None, matching=BaseMatching, non_emitting_length_factor=0.75, **kwargs): """Initialize a matcher for map matching. This a generic base class to be used by matchers. This class itself does not implement a working matcher. Distances are in meters when using latitude-longitude. :param map_con: Map object to connect to map database :param obs_noise: Standard deviation of noise :param obs_noise_ne: Standard deviation of noise for non-emitting states (is set to obs_noise if not give) :param max_dist_init: Maximum distance from start location (if not given, uses max_dist) :param max_dist: Maximum distance from path (this is a hard cut, min_prob_norm should be better) :param min_prob_norm: Minimum normalized probability of observations (ema) :param non_emitting_states: Allow non-emitting states. A non-emitting state is a state that is not associated with an observation. Here we assume it can be associated with a location in between two observations to allow for pruning. It is advised to set min_prob_norm and/or max_dist to avoid visiting all possible nodes in the graph. :param max_lattice_width: Only continue from a limited number of states (thus locations) for a given observation. This possibly speeds up the matching by a lot. If there are more possible next states, the states with the best likelihood so far are selected. The other states are 'delayed'. If the matching is continued later with a larger value using `increase_max_lattice_width`, the algorithms continuous from these delayed states. :param only_edges: Do not include nodes as states, only edges. This is the typical setting for HMM methods. :param matching: Matching type :param non_emitting_length_factor: Reduce the probability of a sequence of non-emitting states the longer it is. This can be used to prefer shorter paths. This is separate from the transition probabilities because transition probabilities are averaged for non-emitting states and thus the length is also averaged out. To define a custom transition and/or emission probability distribtion, overwrite the following functions: - :meth:`logprob_trans` - :meth:`logprob_obs` """ self.map = map_con # type: BaseMap if max_dist: self.max_dist = max_dist else: self.max_dist = np.inf if max_dist_init: self.max_dist_init = max_dist_init else: self.max_dist_init = self.max_dist if min_prob_norm: self.min_logprob_norm = math.log(min_prob_norm) else: self.min_logprob_norm = -np.inf logger.debug(f"Matcher.min_logprob_norm = {self.min_logprob_norm}, Matcher.max_dist = {self.max_dist}") self.obs_noise = obs_noise if obs_noise_ne is None: self.obs_noise_ne = obs_noise else: self.obs_noise_ne = obs_noise_ne self.path = None self.lattice = None # type: Optional[dict[int,LatticeColumn]] # Best path through lattice: self.lattice_best = None # type: Optional[list[BaseMatching]] self.node_path = None # type: Optional[list[str]] self.matching = matching self.non_emitting_states = non_emitting_states # type: bool self.non_emitting_states_maxnb = 100 self.max_lattice_width = max_lattice_width # type: Optional[int] self.only_edges = only_edges # type: bool self.expand_now = 0 # all m.delayed <= expand_upto will be expanded self.early_stop_idx = None # Penalties self.ne_length_factor_log = math.log(non_emitting_length_factor)
[docs] def logprob_trans(self, prev_m, edge_m, edge_o, is_prev_ne=False, is_next_ne=False): # type: (BaseMatcher, BaseMatching, Segment, Segment, bool, bool) -> Tuple[float, Dict[str, Any]] """Transition probability. Note: In contrast with a regular HMM, this cannot be a probability density function, it needs to be a proper probability (thus values between 0.0 and 1.0). :return: probability, properties that are passed to the matching object """ return 0, {} # All probabilities are 1 (thus technically not a distribution)
[docs] def logprob_obs(self, dist, prev_m, new_edge_m, new_edge_o, is_ne=False): """Emission probability. Note: In contrast with a regular HMM, this cannot be a probability density function, it needs to be a proper probability (thus values between 0.0 and 1.0). :return: probability, properties that are passed to the matching object """ return 0, {}
[docs] def match_gpx(self, gpx_file, unique=True): """Map matching from a gpx file""" from ..util.gpx import gpx_to_path path = gpx_to_path(gpx_file) return self.match(path, unique=unique)
def do_stop(self, logprob_norm, dist, logprob_trans, logprob_obs): if logprob_norm < self.min_logprob_norm: logger.debug(f" | Stopped trace: norm(log(Pr)) too small: {logprob_norm} < {self.min_logprob_norm}" f" -- lPr_t = {logprob_trans:.3f}, lPr_o = {logprob_obs:.3f}") return True if dist > self.max_dist: logger.debug(f" | Stopped trace: distance too large: {dist} > {self.max_dist}") return True return False def _insert(self, m_next): return self.lattice[m_next.obs].upsert(m_next)
[docs] def match(self, path, unique=False, tqdm=None, expand=False): """Dynamic Programming based (HMM-like) map matcher. If the matcher fails to match the entire path, the last matched index is returned. This index can be used to run the matcher again from that observation onwards. :param path: list[Union[tuple[lat, lon], tuple[lat, lon, time]] :param unique: Only retain unique nodes in the sequence (avoid repetitions) :param tqdm: Use a tqdm progress reporter (default is None) :param expand: Expand the current lattice (delayed matches) :return: Tuple of (List of BaseMatching, index of last observation that was matched) """ if __debug__: logger.debug("Start matching path of length {}".format(len(path))) # Initialisation if expand: self.expand_now += 1 if self.path != path: is_path_extended = True if len(path) > len(self.path): for pi, spi in zip(path, self.path): if pi != spi: is_path_extended = False break else: is_path_extended = False if is_path_extended: self.lattice[len(self.path) - 1].set_delayed(self.expand_now) for obs_idx in range(len(self.path), len(path)): if obs_idx not in self.lattice: self.lattice[obs_idx] = LatticeColumn(obs_idx) self.path = path else: raise Exception(f'Cannot expand for a new path, should be the same path (or an extension).') else: self.path = path self.expand_now = 0 nb_start_nodes = self._create_start_nodes(use_edges=self.only_edges) if nb_start_nodes == 0: self.lattice_best = [] return [], 0 if __debug__ and logger.isEnabledFor(logging.DEBUG): self.print_lattice(obs_idx=0, label_width=default_label_width, debug=True) # Start iterating over observations 1..end t_start = time.time() iterator = range(1, len(path)) if tqdm: iterator = tqdm(iterator) self.early_stop_idx = None for obs_idx in iterator: if __debug__: logger.debug("--- obs {} --- {} ---".format(obs_idx, self.path[obs_idx])) # check if early stopping has occured cnt_lat_size_not_zero = False for m_tmp in self.lattice[obs_idx - 1].values(0): if not m_tmp.stop: cnt_lat_size_not_zero = True break # if len(self.lattice[obs_idx - 1]) == 0: if not cnt_lat_size_not_zero: if __debug__: logger.debug("No solutions found anymore") self.early_stop_idx = obs_idx - 1 logger.info(f'Stopped early at observation {self.early_stop_idx}') break # Expand matches self._match_states(obs_idx) if self.non_emitting_states: # Fill in non-emitting states between previous and current observation self._match_non_emitting_states(obs_idx - 1, expand=expand) if self.max_lattice_width: # Prune again if non_emitting_states reactives matches from match_states self.lattice[obs_idx].prune(0, self.max_lattice_width, self.expand_now) if __debug__ and logger.isEnabledFor(logging.DEBUG): self.print_lattice(obs_idx=obs_idx, label_width=default_label_width, debug=True) logger.debug(f"--- end obs {obs_idx} ---") t_delta = time.time() - t_start logger.info("--- end ---") logger.info("Build lattice in {} seconds".format(t_delta)) # Backtrack to find best path if not self.early_stop_idx: one_no_stop = False for m in self.lattice[len(path) - 1].values_all(): # todo: could be values(0) ? if not m.stop: one_no_stop = True break if not one_no_stop: self.early_stop_idx = len(path) - 1 if self.early_stop_idx is not None: if self.early_stop_idx == 0: self.lattice_best = [] return [], 0 start_idx = self.early_stop_idx - 1 else: start_idx = len(self.path) - 1 node_path = self._build_node_path(start_idx, unique) return node_path, start_idx
def _skip_ne_states(self, prev_m): # type: (BaseMatcher, BaseMatching) -> bool return False def _create_start_nodes(self, use_edges=True): """Find those nodes that are close to the first point in the path. :return: Number of created start points. """ # Initialisation on first observation if self.expand_now > 0: # No need to search for new points, only activate delayed matches self.lattice[0].prune(0, self.max_lattice_width, self.expand_now) return len(self.lattice[0]) t_start = time.time() self.lattice = dict() for obs_idx in range(len(self.path)): self.lattice[obs_idx] = LatticeColumn(obs_idx) if use_edges: nodes = self.map.edges_closeto(self.path[0], max_dist=self.max_dist_init) else: nodes = self.map.nodes_closeto(self.path[0], max_dist=self.max_dist_init) if __debug__: logger.debug("--- obs {} --- {} ---".format(0, self.path[0])) t_delta = time.time() - t_start logger.info("Initialized lattice with {} starting points in {} seconds".format(len(nodes), t_delta)) if len(nodes) == 0: logger.info(f'Stopped early at observation 0' f', no starting points/edges x found for which ' f'|x - ({self.path[0][0]:.2f},{self.path[0][1]:.2f})| < {self.max_dist_init}') return 0 if __debug__: logger.debug(self.matching.repr_header()) logprob_init = 0 # math.log(1.0/len(nodes)) if use_edges: # Search for nearby edges for dist_obs, label1, loc1, label2, loc2, pi, ti in nodes: if label2 == label1: continue edge_m = Segment(label1, loc1, label2, loc2, pi, ti) edge_o = Segment(f"O{0}", self.path[0]) m_next = self.matching.first(logprob_init, edge_m, edge_o, self, dist_obs) if m_next is not None: self.lattice[0].upsert(m_next) if __debug__: logger.debug(str(m_next)) else: # Search for nearby nodes for dist_obs, label, loc in nodes: edge_m = Segment(label, loc) edge_o = Segment(f"O{0}", self.path[0]) m_next = self.matching.first(logprob_init, edge_m, edge_o, self, dist_obs) if m_next is not None: self.lattice[0].upsert(m_next) if __debug__: logger.debug(str(m_next)) if self.max_lattice_width: self.lattice[0].prune(0, max_lattice_width=self.max_lattice_width, expand_upto=self.expand_now) # if self.non_emitting_states: # self._match_non_emitting_states(0, path) return len(self.lattice[0]) def increase_delayed(self, expand_from=None): if expand_from is None: expand_from = self.expand_now + 1 for col in self.lattice.values(): for colo in col.o: for m in colo.values(): if m.delayed >= expand_from: m.delayed += 1 def _match_states(self, obs_idx, prev_lattice=None, max_dist=None, inc_delayed=False): """Match states :param obs_idx: :param prev_lattice: Start from this list instead of the previous column in the lattice :param max_dist: Use map.*_closeto instead of map.*_nbrto :param inc_delayed: Increase delayed property when new state is created :return: True is new states have been found, False otherwise. """ if prev_lattice is None: prev_lattice = [m for m in self.lattice[obs_idx - 1].values(0) if not m.stop and m.delayed == self.expand_now] count = 0 for m in prev_lattice: # type: BaseMatching if m.stop: assert False # should not happen continue count += 1 if m.edge_m.is_point(): # == Move to neighbour from node == if max_dist is None: nbrs = self.map.nodes_nbrto(m.edge_m.l1) else: nbrs = self.map.nodes_closeto(m.edge_m.p1, max_dist=max_dist) # print("Neighbours for {}: {}".format(m, nbrs)) if nbrs is None: if __debug__: logger.debug("No neighbours found for node {}".format(m.edge_m.l1)) continue if __debug__: logger.debug(" + Move to {} neighbours from node {}".format(len(nbrs), m.edge_m.l1)) logger.debug(m.repr_header()) for nbr_label, nbr_loc in nbrs: # === Move from node to node (or stay on node) === if not self.only_edges: edge_m = Segment(nbr_label, nbr_loc) edge_o = Segment(f"O{obs_idx}", self.path[obs_idx]) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if inc_delayed: m_next.delayed += 1 self._insert(m_next) if __debug__: logger.debug(str(m_next)) # === Move from node to edge === if m.edge_m.l1 != nbr_label: edge_m = Segment(m.edge_m.l1, m.edge_m.p1, nbr_label, nbr_loc) edge_o = Segment(f"O{obs_idx}", self.path[obs_idx]) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if inc_delayed: m_next.delayed += 1 self._insert(m_next) if __debug__: logger.debug(str(m_next)) else: if __debug__: logger.debug(self.matching.repr_static(('x', f'{nbr_label}-{nbr_label} < self-loop'))) else: # == Move to neighbour from edge == if __debug__: logger.debug(" + Move to neighbour from edge {}".format(m.label)) logger.debug(m.repr_header()) # === Stay on edge === edge_m = Segment(m.edge_m.l1, m.edge_m.p1, m.edge_m.l2, m.edge_m.p2) edge_o = Segment(f"O{obs_idx}", self.path[obs_idx]) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if inc_delayed: m_next.delayed += 1 self._insert(m_next) if __debug__: logger.debug(str(m_next)) # === Move from edge to node === if not self.only_edges: edge_m = Segment(m.edge_m.l2, m.edge_m.p2) edge_o = Segment(f"O{obs_idx}", self.path[obs_idx]) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if inc_delayed: m_next.delayed += 1 self._insert(m_next) if __debug__: logger.debug(str(m_next)) else: # === Move from edge to next edge === if max_dist is None: nbrs = self.map.edges_nbrto((m.edge_m.l1, m.edge_m.l2)) # type: list else: nbrs = [(l1, p1, l2, p2) for _, l1, p1, l2, p2, _, _ in self.map.edges_closeto(m.edge_m.pi, max_dist=max_dist)] if nbrs is None or len(nbrs) == 0: if __debug__: logger.debug(f"No neighbours found for edge {m.edge_m.label}") continue for nbr_label1, nbr_loc1, nbr_label2, nbr_loc2 in nbrs: # same edge is different action, opposite edge should be allowed to return in a one-way street if m.edge_m.l2 != nbr_label2 and m.edge_m.l1 != nbr_label1: edge_m = Segment(nbr_label1, nbr_loc1, nbr_label2, nbr_loc2) edge_o = Segment(f"O{obs_idx}", self.path[obs_idx]) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if inc_delayed: m_next.delayed += 1 self._insert(m_next) if __debug__: mstr = str(m_next) logger.debug(mstr) if self.max_lattice_width: self.lattice[obs_idx].prune(0, self.max_lattice_width, self.expand_now) if count == 0: if __debug__: logger.debug("No active solution found anymore") return False return True def _match_non_emitting_states(self, obs_idx, expand=False): """Match sequences of nodes that all refer to the same observation at obs_idx. Assumptions: This method assumes that the lattice is filled up for both obs_idx and obs_idx + 1. :param obs_idx: Index of the first observation used (the second will be obs_idx + 1) :return: None """ obs = self.path[obs_idx] if obs_idx < len(self.path) - 1: obs_next = self.path[obs_idx + 1] else: obs_next = None # The current states are the current observation's states if expand: cur_lattice = dict((m.key, m) for m in self.lattice[obs_idx].values(0) if not m.stop and m.delayed == self.expand_now) else: cur_lattice = dict((m.key, m) for m in self.lattice[obs_idx].values(0) if not (m.stop or m.delayed > 0)) lattice_toinsert = list() # The current best states are the next observation's states if you would ignore non-emitting states lattice_best = dict((m.shortkey, m) for m in self.lattice[obs_idx + 1].values(0) if not m.stop) lattice_ne = set(m.shortkey for m in self.lattice[obs_idx + 1].values(0) if not m.stop and self._skip_ne_states(m)) # cur_lattice = set(self.lattice[obs_idx].values()) nb_ne = 0 prune_thr = None while len(cur_lattice) > 0 and nb_ne < self.non_emitting_states_maxnb: nb_ne += 1 if __debug__: logger.debug("--- obs {}:{} --- {} - {} ---".format(obs_idx, nb_ne, obs, obs_next)) cur_lattice = self._match_non_emitting_states_inner(cur_lattice, obs_idx, obs, obs_next, nb_ne, lattice_best, lattice_ne) if self.max_lattice_width is not None: self.lattice[obs_idx].prune(nb_ne, self.max_lattice_width, self.expand_now, prune_thr) # Link to next observation self._match_non_emitting_states_end(cur_lattice, obs_idx + 1, obs_next, lattice_best, expand=expand) if self.max_lattice_width is not None: prune_thr = self.lattice[obs_idx + 1].prune(0, self.max_lattice_width, self.expand_now, None) if self.max_lattice_width is not None: self.lattice[obs_idx + 1].prune(0, self.max_lattice_width, self.expand_now, None) # logger.info('Used {} levels of non-emitting states'.format(nb_ne)) # for m in lattice_toinsert: # self._insert(m) def _node_in_prev_ne(self, m_next, label): """Is the given node already visited in the chain of non-emitting states. :param m_next: :param label: Node label :return: True or False """ # for m in itertools.chain(m_next.prev, m_next.prev_other): for m in m_next.prev: # type: BaseMatching if m.obs != m_next.obs: return False assert(m_next.obs_ne != m.obs_ne) # print('prev', m.shortkey, 'checking for ', label) # if label == m.shortkey: if label in m.nodes: return True if m.obs_ne == 0: return False if self._node_in_prev_ne(m, label): return True return False @staticmethod def _insert_tmp(m_next, lattice): if m_next.key in lattice: return lattice[m_next.key].update(m_next) else: lattice[m_next.key] = m_next return True def _match_non_emitting_states_inner(self, cur_lattice, obs_idx, obs, obs_next, nb_ne, lattice_best, lattice_ne): # cur_lattice_new = dict() cur_lattice_new = self.lattice[obs_idx].dict(nb_ne) for m in cur_lattice.values(): # type: BaseMatching if m.stop or m.delayed != self.expand_now: continue if m.shortkey in lattice_ne: logger.debug(f"Skip non-emitting states from {m.label}, already visited") continue # == Move to neighbour edge from edge == if m.edge_m.l2 is not None and self.only_edges: nbrs = self.map.edges_nbrto((m.edge_m.l1, m.edge_m.l2)) # print("Neighbours for {}: {}".format(m, nbrs)) if nbrs is None or len(nbrs) == 0: if __debug__: logger.debug(f"No neighbours found for edge {m.edge_m.label} ({m.label}, non-emitting)") continue for nbr_label1, nbr_loc1, nbr_label2, nbr_loc2 in nbrs: if self._node_in_prev_ne(m, nbr_label2): if __debug__: logger.debug(self.matching.repr_static(('x', '{} < node in prev ne'.format(nbr_label2)))) continue # === Move to next edge === if m.edge_m.l2 != nbr_label2 and m.edge_m.l1 != nbr_label2: edge_m = Segment(nbr_label1, nbr_loc1, nbr_label2, nbr_loc2) edge_o = Segment(f"O{obs_idx}", obs, f"O{obs_idx+1}", obs_next) m_next = m.next(edge_m, edge_o, obs=obs_idx, obs_ne=nb_ne) if m_next is not None: if m_next.key in cur_lattice_new: if m_next.shortkey in lattice_best: if approx_leq(m_next.dist_obs, lattice_best[m_next.shortkey].dist_obs): cur_lattice_new[m_next.key].update(m_next) else: m_next.stop = True if __debug__ and logger.isEnabledFor(logging.DEBUG): logger.debug(f" | Stopped trace: distance larger than best for key {m_next.shortkey}: " f"{m_next.dist_obs} > {lattice_best[m_next.shortkey].dist_obs}") else: cur_lattice_new[m_next.key].update(m_next) else: if m_next.shortkey in lattice_best: # if m_next.logprob > lattice_best[m_next.shortkey].logprob: if approx_leq(m_next.dist_obs, lattice_best[m_next.shortkey].dist_obs): cur_lattice_new[m_next.key] = m_next # lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) else: if __debug__ and logger.isEnabledFor(logging.DEBUG): logger.debug(f" | Stopped trace: distance larger than best for key {m_next.shortkey}: " f"{m_next.dist_obs} > {lattice_best[m_next.shortkey].dist_obs}") m_next.stop = True else: cur_lattice_new[m_next.key] = m_next # lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) # cur_lattice_new.add(m_next) if __debug__: logger.debug(str(m_next)) else: if __debug__: logger.debug(self.matching.repr_static(('x', f'{nbr_label1}-{nbr_label2} < goes back (ne)'))) # == Move to neighbour node from node== if m.edge_m.l2 is None and not self.only_edges: cur_node = m.edge_m.l1 nbrs = self.map.nodes_nbrto(cur_node) if nbrs is None: if __debug__: logger.debug( f"No neighbours found for node {cur_node} ({m.label}, non-emitting)") continue if __debug__: logger.debug( f" + Move to {len(nbrs)} neighbours from node {cur_node} ({m.label}, non-emitting)") logger.debug(m.repr_header()) for nbr_label, nbr_loc in nbrs: # print(f"self._node_in_prev_ne({m.label}, {nbr_label}) = {self._node_in_prev_ne(m, nbr_label)}") if self._node_in_prev_ne(m, nbr_label): if __debug__: logger.debug(self.matching.repr_static(('x', '{} < node in prev ne'.format(nbr_label)))) continue # === Move to next node === if m.edge_m.l1 != nbr_label: edge_m = Segment(nbr_label, nbr_loc) edge_o = Segment(f"O{obs_idx}", obs, f"O{obs_idx+1}", obs_next) m_next = m.next(edge_m, edge_o, obs=obs_idx, obs_ne=nb_ne) if m_next is not None: if m_next.key in cur_lattice_new: cur_lattice_new[m_next.key].update(m_next) else: if m_next.shortkey in lattice_best: # if m_next.logprob > lattice_best[m_next.shortkey].logprob: if m_next.dist_obs < lattice_best[m_next.shortkey].dist_obs: cur_lattice_new[m_next.key] = m_next lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) elif __debug__ and logger.isEnabledFor(logging.DEBUG): m_next.stop = True cur_lattice_new[m_next.key] = m_next # lattice_toinsert.append(m_next) else: cur_lattice_new[m_next.key] = m_next lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) # cur_lattice_new.add(m_next) if __debug__: logger.debug(str(m_next)) else: if __debug__: logger.debug(f"x | {m.edge_m.l1}-{nbr_label} < self-loop") return cur_lattice_new def _match_non_emitting_states_end(self, cur_lattice, obs_idx, obs_next, lattice_best, expand=False): for m in cur_lattice.values(): # type: BaseMatching if m.stop or m.delayed > self.expand_now: continue if m.edge_m.l2 is not None: # Move to neighbour edge from edge nbrs = self.map.edges_nbrto((m.edge_m.l1, m.edge_m.l2)) # print("Neighbours for {}: {}".format(m, nbrs)) if nbrs is None or len(nbrs) == 0: if __debug__: logger.debug("No neighbours found for edge {} ({})".format(m.edge_m.label, m.label)) continue if __debug__: logger.debug(f" + Move to {len(nbrs)} neighbours from edge {m.edge_m.label} " f"({m.label}, non-emitting->emitting)") logger.debug(m.repr_header()) for nbr_label1, nbr_loc1, nbr_label2, nbr_loc2 in nbrs: if self._node_in_prev_ne(m, nbr_label2): if __debug__: logger.debug(self.matching.repr_static(('x', '{} < node in prev ne'.format(nbr_label2)))) continue # Move to next edge if m.edge_m.l1 != nbr_label2 and m.edge_m.l2 != nbr_label2: edge_m = Segment(nbr_label1, nbr_loc1, nbr_label2, nbr_loc2) edge_o = Segment(f"O{obs_idx+1}", obs_next) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if m_next.shortkey in lattice_best: # if m_next.dist_obs < lattice_best[m_next.shortkey].dist_obs: if m_next.logprob > lattice_best[m_next.shortkey].logprob: lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) self.lattice[obs_idx].upsert(m_next) elif __debug__ and logger.isEnabledFor(logging.DEBUG): m_next.stop = True # lattice_toinsert.append(m_next) self.lattice[obs_idx].upsert(m_next) else: lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) self.lattice[obs_idx].upsert(m_next) if __debug__: logger.debug(str(m_next)) else: if __debug__: logger.debug(self.matching.repr_static(('x', '{} < going back'.format(nbr_label2)))) else: # m.edge_m.l2 is None: # Move to neighbour node from node cur_node = m.edge_m.l1 nbrs = self.map.nodes_nbrto(cur_node) # print("Neighbours for {}: {}".format(m, nbrs)) if nbrs is None: if __debug__: logger.debug("No neighbours found for node {}".format(cur_node, m.label)) continue if __debug__: logger.debug(f" + Move to {len(nbrs)} neighbours from node {cur_node} " f"({m.label}, non-emitting->emitting)") logger.debug(m.repr_header()) for nbr_label, nbr_loc in nbrs: if self._node_in_prev_ne(m, nbr_label): if __debug__: logger.debug(self.matching.repr_static(('x', '{} < node in prev ne'.format(nbr_label)))) continue # Move to next node if m.edge_m.l1 != nbr_label: # edge_m = Segment(m.edge_m.l1, m.edge_m.p1, nbr_label, nbr_loc) edge_m = Segment(nbr_label, nbr_loc) edge_o = Segment(f"O{obs_idx+1}", obs_next) m_next = m.next(edge_m, edge_o, obs=obs_idx) if m_next is not None: if m_next.shortkey in lattice_best: # if m_next.dist_obs < lattice_best[m_next.shortkey].dist_obs: if m_next.logprob > lattice_best[m_next.shortkey].logprob: lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) self.lattice[obs_idx].upsert(m_next) elif __debug__ and logger.isEnabledFor(logging.DEBUG): m_next.stop = True # lattice_toinsert.append(m_next) self.lattice[obs_idx].upsert(m_next) else: lattice_best[m_next.shortkey] = m_next # lattice_toinsert.append(m_next) self.lattice[obs_idx].upsert(m_next) if __debug__: logger.debug(str(m_next)) else: if __debug__: logger.debug(self.matching.repr_static(('x', '{} < self-loop'.format(nbr_label)))) def get_matching(self, identifier=None): m = None # type: Optional[BaseMatching] if isinstance(identifier, BaseMatching): m = identifier elif identifier is None: col = self.lattice[len(self.lattice) - 1] for curm in col.values_all(): if m is None or curm.logprob > m.logprob: m = curm elif type(identifier) is int: # If integer, search for the best matching at this index in the lattice for cur_m in self.lattice[identifier].values_all(): # type:BaseMatching if not cur_m.stop and (m is None or cur_m.logprob > m.logprob): m = cur_m elif type(identifier) is str: # If string, try to parse identifier parts = identifier.split('-') idx, ne, key = None, None, None if len(parts) == 4: nodea, nodeb, idx, ne = [int(part) for part in parts] key = (nodea, nodeb, idx, ne) col = self.lattice[idx] # type: LatticeColumn col_ne = col.o[ne] m = col_ne[key] elif len(parts) == 3: node, idx, ne = [int(part) for part in parts] key = (node, idx, ne) col = self.lattice[idx] # type: LatticeColumn col_ne = col.o[ne] m = col_ne[key] elif len(parts) == 1: m = None l1 = int(parts[0]) for l in self.lattice.values(): # type: LatticeColumn for curm in l.values_all(): if (curm.edge_m.l1 == l1 or curm.edge_m.l2 == l1) and \ (m is None or curm.logprob > m.logprob): m = curm else: raise AttributeError(f'Unknown string format for matching. ' 'Expects <node>-<idx>-<ne> or <node>-<node>-<idx>-<ne>.') return m
[docs] def get_matching_path(self, start_m): """List of Matching objects that end in the given Matching object.""" start_m = self.get_matching(start_m) return self._build_matching_path(start_m)
[docs] def get_node_path(self, start_m, only_nodes=False): """List of node/edge names that end in the given Matching object.""" path = self.get_matching_path(start_m) node_path = [m.shortkey for m in path] if only_nodes: node_path = self.node_path_to_only_nodes(node_path) return node_path
[docs] def get_path(self, only_nodes=True, allow_jumps=False, only_closest=True): """A list with all the nodes (no edges) the matched path passes through.""" if only_nodes is False: return self.node_path if self.node_path is None or len(self.node_path) == 0: return [] path = self.node_path_to_only_nodes(self.node_path, allow_jumps=allow_jumps) if only_closest: m = self.lattice_best[0] if m.edge_m.ti > 0.5: path.pop(0) return path
[docs] def node_path_to_only_nodes(self, path, allow_jumps=False): """Path of nodes and edges to only nodes. :param path: List of node names or edges as (node name, node name) :param allow_jumps: Allow a path over edges that are not connected. This occurs when matches are added without an edge, for example, when searching for edges in the distance neighborhood instead in the graph. :return: List of node names """ nodes = [] prev_state = path[0] if type(prev_state) is tuple: nodes.append(prev_state[0]) nodes.append(prev_state[1]) prev_node = prev_state[1] else: nodes.append(prev_state) prev_node = prev_state for state in path[1:]: if state == prev_state: continue if type(state) is not tuple: if state != prev_node: nodes.append(state) prev_node = state elif type(state) is tuple: if state[0] == prev_node: if state[1] != prev_node: nodes.append(state[1]) prev_node = state[1] elif state[1] == prev_node: if state[0] != prev_node: nodes.append(state[0]) prev_node = state[0] elif not allow_jumps: raise Exception(f"State {state} does not have as previous node {prev_node}") else: nodes.append(state[0]) nodes.append(state[1]) prev_node = state[1] else: raise Exception(f"Unknown type of state: {state} ({type(state)})") prev_state = state return nodes
def _build_matching_path(self, start_m, max_depth=None): lattice_best = [] node_max = start_m cur_depth = 0 if __debug__ and logger.isEnabledFor(logging.DEBUG): logger.debug(self.matching.repr_header(stop=" ")) logger.debug("Start ({}): {}".format(node_max.obs, node_max)) lattice_best.append(node_max) if node_max.is_emitting(): cur_depth += 1 # for obs_idx in reversed(range(start_idx)): if max_depth is None: max_depth = len(self.lattice) + 1 while cur_depth < max_depth and len(node_max.prev) > 0: node_max_last = node_max node_max: Optional[BaseMatching] = None for prev_m in node_max_last.prev: if prev_m is not None and (node_max is None or prev_m.logprob > node_max.logprob): node_max = prev_m if node_max is None: logger.error("Did not find a matching node for path point at index {}. ".format(node_max_last.obs) + "Stopped building path.") break logger.debug("Max ({}): {}".format(node_max.obs, node_max)) lattice_best.append(node_max) if node_max.is_emitting(): cur_depth += 1 lattice_best = list(reversed(lattice_best)) return lattice_best def _build_node_path(self, start_idx, unique=True, max_depth=None, last_is_e=False): """Build the path from the lattice. :param start_idx: :param unique: :param max_depth: :param last_is_e: Last matched lattice node should be an emitting state. In case the matching stops early, the longest path can be in between two observations and thus be a nonemitting state (which by definition has a lower probability than the last emitting state). If this argument is set to true, the longer match is preferred. :return: """ node_max = None node_max_ne = 0 if last_is_e: for m in self.lattice[start_idx].values_all(): # type:BaseMatching if not m.stop and (node_max is None or m.logprob > node_max.logprob): node_max = m else: for m in self.lattice[start_idx].values_all(): # type:BaseMatching if not m.stop and (node_max is None or m.obs_ne > node_max_ne or m.logprob > node_max.logprob): node_max_ne = m.obs_ne node_max = m if node_max is None: logger.error("Did not find a matching node for path point at index {}".format(start_idx)) return None self.lattice_best = self._build_matching_path(node_max, max_depth) node_path = [m.shortkey for m in self.lattice_best] if unique: self.node_path = [] prev_node = None for node in node_path: if node != prev_node: self.node_path.append(node) prev_node = node else: self.node_path = node_path return self.node_path def increase_max_lattice_width(self, max_lattice_width, unique=False, tqdm=None): self.max_lattice_width = max_lattice_width return self.match(self.path, unique=unique, tqdm=tqdm, expand=True)
[docs] def continue_with_distance(self, from_matches=None, k=2, nb_obs=2, max_dist=None): """Continue the matcher but ignore edges and allow jumps to nearby edged. :param from_matches: Search in the neigborhood of these matches :param k: If from_matches is not given, the k best matches are used in the last nb_obs observations since last early_stop_idx :praram nb_obs: If from_matches is not given, the k best matches are used in the last nb_obs observations since last early_stop_idx :param max_dist: Add edges that are maximally max_dist away from the previous match. If none, self.max_dist * 3 is used. """ if from_matches is None: from_matches = self.best_last_matches(k=k, nb_obs=nb_obs) self.increase_delayed() if max_dist is None: max_dist = self.max_dist * 3 for obs_idx, cur_matches in from_matches.items(): self._match_states(obs_idx, prev_lattice=cur_matches, max_dist=max_dist, inc_delayed=True)
[docs] def path_bb(self): """Get boundig box of matched path (if it exists, otherwise return None).""" path = self.path plat, plon = islice(zip(*path), 2) lat_min, lat_max = min(plat), max(plat) lon_min, lon_max = min(plon), max(plon) bb = lat_min, lon_min, lat_max, lon_max return bb
def print_lattice(self, file=None, obs_idx=None, obs_ne=0, label_width=None, debug=False): if debug: xprint = logger.debug else: if file is None: file = sys.stdout xprint = lambda arg: print(arg, file=file) # print("Lattice:", file=file) if obs_idx is not None: idxs = [obs_idx] else: idxs = range(len(self.lattice)) for idx in idxs: if len(self.lattice[idx]) > 0: if label_width is None: label_width = 0 for m in self.lattice[idx].values(obs_ne): label_width = max(label_width, len(str(m.label))) xprint("--- obs {} ---".format(idx)) xprint(self.matching.repr_header(label_width=label_width)) for m in sorted(self.lattice[idx].values(obs_ne), key=lambda t: str(t.label)): xprint(m.__str__(label_width=label_width))
[docs] def lattice_dot(self, file=None, precision=None, render=False): """Write the lattice as a Graphviz DOT file. :param file: File object to print to. Prints to stdout if None. :param precision: Precision of (log) probabilities. :param render: Try to render the generated Graphviz file. """ if file is None: file = sys.stdout if precision is None: prfmt = '' else: prfmt = f'.{precision}f' print('digraph lattice {', file=file) print('\trankdir=LR;', file=file) # Vertices for idx_ob in range(len(self.lattice)): col = self.lattice[idx_ob] for idx_ne in range(len(col)): ms = col.values(idx_ne) if len(ms) == 0: continue cnames = [(m.obs_ne, m.cname, m.stop, m.delayed) for m in ms] cnames.sort() cur_obs_ne = -1 print('\t{\n\t\trank=same; ', file=file) for obs_ne, cname, stop, delayed in cnames: if obs_ne != cur_obs_ne: if cur_obs_ne != -1: print('\t};\n\t{\n\t\trank=same; ', file=file) cur_obs_ne = obs_ne if stop: options = 'label="{} x",color=gray,fontcolor=gray'.format(cname) elif delayed > self.expand_now: options = 'label="{} d{}",color=gray,fontcolor=gray'.format(cname, delayed) elif self.expand_now != 0: options = 'label="{} d{}"'.format(cname, delayed) else: options = 'label="{} "'.format(cname) print('\t\t{} [{}];'.format(cname, options), file=file) print('\t};', file=file) # Edges for idx_ob in range(len(self.lattice)): col = self.lattice[idx_ob] for idx_ne in range(len(col)): ms = col.values(idx_ne) if len(ms) == 0: continue for m in ms: for mp in m.prev: if m.stop or m.delayed > self.expand_now: options = ',color=gray,fontcolor=gray' else: options = '' print(f'\t {mp.cname} -> {m.cname} [label="{m.logprob:{prfmt}}"{options}];', file=file) for mp in m.prev_other: if m.stop or m.delayed > self.expand_now: options = ',color=gray,fontcolor=gray' else: options = '' print(f'\t {mp.cname} -> {m.cname} [color=gray,label="{m.logprob:{prfmt}}"{options}];', file=file) print('}', file=file) if render and file is not None: import subprocess as sp from pathlib import Path from io import TextIOWrapper if isinstance(file, Path): fn = str(file.canonical()) elif isinstance(file, TextIOWrapper): file.flush() fn = file.name else: fn = str(file) cmd = ['dot', '-Tpdf', '-O', fn] logger.debug(' '.join(cmd)) sp.call(cmd)
def print_lattice_stats(self, file=None, verbose=False): if file is None: file = sys.stdout print("Stats lattice", file=file) print("-------------", file=file) stats = OrderedDict() stats["nbr levels"] = len(self.lattice) if self.lattice else "?" total_nodes = 0 max_nodes = 0 min_nodes = 9999999 if self.lattice: sizes = [] for idx in range(len(self.lattice)): level = self.lattice[idx].values(0) # stats["#nodes[{}]".format(idx)] = len(level) sizes.append(len(level)) total_nodes += len(level) if len(level) < min_nodes: min_nodes = len(level) if len(level) > max_nodes: max_nodes = len(level) stats["nbr lattice"] = total_nodes if verbose: stats["nbr lattice[level]"] = ", ".join([str(s) for s in sizes]) stats["avg lattice[level]"] = total_nodes/len(self.lattice) stats["min lattice[level]"] = min_nodes stats["max lattice[level]"] = max_nodes if self.lattice_best and len(self.lattice_best) > 0: stats["avg obs distance"] = np.mean([m.dist_obs for m in self.lattice_best]) stats["last logprob"] = self.lattice_best[-1].logprob stats["last length"] = self.lattice_best[-1].length stats["last norm logprob"] = self.lattice_best[-1].logprob / self.lattice_best[-1].length if verbose: stats["best logprob"] = ", ".join(["{:.3f}".format(m.logprob) for m in self.lattice_best]) stats["best norm logprob"] = \ ", ".join(["{:.3f}".format(m.logprob/m.length) for i, m in enumerate(self.lattice_best)]) stats["best norm prob"] = \ ", ".join(["{:.3f}".format(math.exp(m.logprob/m.length)) for i, m in enumerate(self.lattice_best)]) for key, val in stats.items(): print("{:<24} : {}".format(key, val), file=file) def node_counts(self): if self.lattice is None: return None counts = defaultdict(lambda: 0) for level in self.lattice.values(): for m in level.values_all(): counts[m.label] += 1 return counts
[docs] def inspect_early_stopping(self): """Analyze the lattice and try to find most plausible reason why the matching stopped early and print to stdout.""" if self.early_stop_idx is None: print("No early stopping.") return col = self.lattice[self.early_stop_idx - 1] print("The last matched nodes or edges were:") first_row = True ignore = set() for ne_i in range(len(col.o) - 1, -1, -1): for v in col.o[ne_i].values(): if v.key not in ignore: if first_row: print(v.repr_header()) first_row = False print(v) ignore.update(r.key for r in v.prev)
[docs] def best_last_matches(self, k=1, nb_obs=3): """Return the k best last matches. :param k: Number of best matches to keep for an observation :param nb_obs: How many last matched observations to consider """ import heapq if self.early_stop_idx is None: col_idx = len(self.lattice) - 1 else: col_idx = self.early_stop_idx - 1 hh = [] obs_cnt = 0 while col_idx >= 0 and obs_cnt < nb_obs: h = [] col = self.lattice[col_idx] col_oneselected = False for ne_i in range(len(col.o) - 1, -1, -1): for v in col.o[ne_i].values(): if v.stop: continue if len(h) < k: heapq.heappush(h, (v.logprob, v)) col_oneselected = True elif v.logprob > h[0][0]: heapq.heappop(h) heapq.heappush(h, (v.logprob, v)) col_oneselected = True hh.extend(h) if col_oneselected is False: print(f'break in {col_idx=}') break col_idx -= 1 obs_cnt += 1 result = defaultdict(list) for m in hh: m = m[1] result[m.obs + 1].append(m) # return [m[1] for m in hh] return result
[docs] def copy_lastinterface(self, nb_interfaces=1): """Copy the current matcher and keep the last interface as the start point. This method allows you to perform incremental matching without keeping the entire lattice in memory. You need to run :meth:`match_incremental` on this object to continue from the existing (partial) lattice. Otherwise, if you use :meth:`match`, it will be overwritten. Open question, if there is no need to keep track of older lattices, it will probably be more efficient to clear the older parts of the interface instead of copying the newer parts. :param nb_interfaces: Nb of interfaces (columns in lattice) to keep. Default is 1, the last one. :return: new Matcher object """ matcher = self.__class__(self.map, obs_noise=self.obs_noise, max_dist_init=self.max_dist_init, max_dist=self.max_dist, min_prob_norm=self.min_logprob_norm, non_emitting_states=self.non_emitting_states, max_lattice_width=self.max_lattice_width, only_edges=self.only_edges, obs_noise_ne=self.obs_noise_ne, matching=self.matching, avoid_goingback=self.avoid_goingback, non_emitting_length_factor=math.exp(self.ne_length_factor_log)) matcher.lattice = [] matcher.path = [] for int_i in range(len(self.lattice) - nb_interfaces, len(self.lattice)): matcher.lattice.append(self.lattice[int_i]) matcher.path.append(self.path[int_i]) return matcher
@property def path_pred(self): """The matched path, both nodes and/or edges (depending on your settings).""" return self.node_path @property def path_pred_onlynodes(self): """A list with all the nodes (no edges) the matched path passes through.""" return self.get_path(only_nodes=True, allow_jumps=False) @property def path_pred_onlynodes_withjumps(self): """A list with all the nodes (no edges) the matched path passes through.""" return self.get_path(only_nodes=True, allow_jumps=True)
[docs] def path_pred_distance(self): """Total distance of the matched path.""" if self.lattice_best is None: return None if len(self.lattice_best) == 1: return 0 dist = 0 m_prev = self.lattice_best[0] for idx, m in enumerate(self.lattice_best[1:]): if m_prev.edge_m.label != m.edge_m.label and m_prev.edge_m.l2 == m.edge_m.l1: # Go over the connection between two edges to compute the distance cdist = self.map.distance(m_prev.edge_m.pi, m_prev.edge_m.p2) cdist += self.map.distance(m_prev.edge_m.p2, m.edge_m.pi) else: cdist = self.map.distance(m_prev.edge_m.pi, m.edge_m.pi) dist += cdist m_prev = m return dist
[docs] def path_distance(self): """Total distance of the observations.""" if self.lattice_best is None: return None if len(self.lattice_best) == 1: return 0 dist = 0 m_prev = self.lattice_best[0] for m in self.lattice_best[1:]: dist += self.map.distance(m_prev.edge_o.pi, m.edge_o.pi) m_prev = m return dist
[docs] def path_all_distances(self): """Return a list of all distances between observed trace and map. One entry for each point in the map and point in the trace that are mapped to each other. In case non-emitting nodes are used, extra entries can be present where a point in the trace or a point in the map is mapped to a segment. """ path = self.lattice_best dists = [m.dist_obs for m in path] return dists