Source code for semra.inference

"""Inference functionality for SeMRA."""

from __future__ import annotations

import itertools as itt
import typing as t
from collections import Counter, defaultdict
from collections.abc import Iterable

import bioregistry
import networkx as nx
from pydantic import BaseModel
from tqdm.asyncio import tqdm

from semra.api import assemble_evidences, flip
from semra.io.graph import MULTIDIGRAPH_DATA_KEY, to_multidigraph
from semra.rules import FLIP, GENERALIZATIONS
from semra.struct import Evidence, Mapping, ReasonedEvidence, Reference
from semra.utils import cleanup_prefixes, semra_tqdm
from semra.vocabulary import (
    BROAD_MATCH,
    CHAIN_MAPPING,
    DB_XREF,
    EXACT_MATCH,
    KNOWLEDGE_MAPPING,
    NARROW_MATCH,
)

__all__ = [
    "infer_chains",
    "infer_dbxref_mutations",
    "infer_generalizations",
    "infer_mutations",
    "infer_mutual_dbxref_mutations",
    "infer_reversible",
]


[docs] def infer_reversible(mappings: t.Iterable[Mapping], *, progress: bool = True) -> list[Mapping]: """Extend the mapping list with flipped mappings. :param mappings: An iterable of mappings :param progress: Should a progress bar be shown? Defaults to true. :returns: A list where if a mapping can be flipped (i.e., :func:`flip`), a flipped mapping is added. Flipped mappings contain reasoned evidence :class:`ReasonedEvidence` objects that point to the mapping from which the evidence was derived. Flipping a mapping means switching the subject and object, then modifying the predicate as follows: 1. Broad becomes narrow 2. Narrow becomes broad 3. Exact and close mappings remain the same, since they're reflexive This is configured in the :data:`semra.rules.FLIP` dictionary. >>> from semra import Mapping, Reference, EXACT_MATCH, SimpleEvidence >>> from semra.api import get_test_evidence, get_test_reference >>> r1, r2 = get_test_reference(2) >>> e1 = get_test_evidence() >>> m1 = Mapping(subject=r1, predicate=EXACT_MATCH, object=r2, evidence=[e1]) >>> mappings = infer_reversible([m1], progress=False) >>> len(mappings) 2 >>> assert mappings[0] == m1 .. warning:: This operation does not "assemble", meaning if you had existing evidence for an inverse mapping, they will be seperate. Therefore, you can chain it with the :func:`semra.api.assemble_evidences` operation: >>> from semra import Mapping, Reference, EXACT_MATCH >>> from semra.api import get_test_evidence >>> from semra.api import get_test_evidence, get_test_reference >>> r1, r2 = get_test_reference(2) >>> e1, e2 = get_test_evidence(2) >>> m1 = Mapping(subject=r1, predicate=EXACT_MATCH, object=r2, evidence=[e1]) >>> m2 = Mapping(subject=r2, predicate=EXACT_MATCH, object=r1, evidence=[e2]) >>> mappings = infer_reversible([m1, m2]) >>> len(mappings) 4 >>> mappings = assemble_evidences(mappings) >>> len(mappings) 2 """ rv = [] for mapping in semra_tqdm(mappings, desc="Infer reverse", progress=progress): rv.append(mapping) if flipped_mapping := flip(mapping): rv.append(flipped_mapping) return rv
[docs] def infer_chains( mappings: list[Mapping], *, backwards: bool = True, progress: bool = True, cutoff: int = 5, minimum_component_size: int = 2, maximum_component_size: int = 100, ) -> list[Mapping]: """Apply graph-based reasoning over mapping chains to infer new mappings. :param mappings: A list of input mappings :param backwards: Should inference be done in reverse? :param progress: Should a progress bar be shown? Defaults to true. :param cutoff: What's the maximum length path to infer over? :param minimum_component_size: The smallest size of a component to consider, defaults to 2 :param maximum_component_size: The smallest size of a component to consider, defaults to 100. Components that are very large (i.e., much larger than the number of target prefixes) likely are the result of many broad/narrow mappings :return: The list of input mappings _plus_ inferred mappings """ mappings = assemble_evidences(mappings, progress=progress) graph = to_multidigraph(mappings) new_mappings = [] components = sorted( ( component for component in nx.weakly_connected_components(graph) if minimum_component_size < len(component) <= maximum_component_size ), key=len, reverse=True, ) it = tqdm( components, unit="component", desc="Inferring chains", unit_scale=True, disable=not progress ) for _i, component in enumerate(it): sg: nx.MultiDiGraph = graph.subgraph(component).copy() sg_len = sg.number_of_nodes() it.set_postfix(size=sg_len) inner_it = tqdm( itt.combinations(sg, 2), total=sg_len * (sg_len - 1) // 2, unit_scale=True, disable=not progress, unit="edge", leave=False, ) for s, o in inner_it: if sg.has_edge(s, o): # do not overwrite existing mappings continue # TODO there has to be a way to reimplement transitive closure to handle this # nx.shortest_path(sg, s, o) predicate_evidence_dict: defaultdict[Reference, list[Evidence]] = defaultdict(list) for path in nx.all_simple_edge_paths(sg, s, o, cutoff=cutoff): if _path_has_prefix_duplicates(path): continue predicates = [k for _u, _v, k in path] p = _reason_multiple_predicates(predicates) if p is not None: evidence = ReasonedEvidence( justification=CHAIN_MAPPING, mappings=[ Mapping( subject=path_s, object=path_o, predicate=path_p, evidence=graph[path_s][path_o][path_p][MULTIDIGRAPH_DATA_KEY], ) for path_s, path_o, path_p in path ], # TODO add confidence that's inversely proportional to sg_len, i.e. # larger components should return less confident mappings ) predicate_evidence_dict[p].append(evidence) for p, evidences in predicate_evidence_dict.items(): new_mappings.append(Mapping(subject=s, predicate=p, object=o, evidence=evidences)) if backwards: new_mappings.append( Mapping(object=s, subject=o, predicate=FLIP[p], evidence=evidences) ) return [*mappings, *new_mappings]
def _reason_multiple_predicates(predicates: t.Iterable[Reference]) -> Reference | None: """Return a single reasoned predicate based on a set, if possible. :param predicates: A collection of predicates :return: A single predicate that represents the set, if possible For example, if a predicate set with exact + broad are given, then the most specific possible is exact. If a predicate contains exact, broad, and narrow, then no reasoning can be done and None is returned. """ predicate_set = set(predicates) if predicate_set == {EXACT_MATCH}: return EXACT_MATCH if predicate_set == {BROAD_MATCH} or predicate_set == {EXACT_MATCH, BROAD_MATCH}: return BROAD_MATCH if predicate_set == {NARROW_MATCH} or predicate_set == {EXACT_MATCH, NARROW_MATCH}: return NARROW_MATCH return None def _path_has_prefix_duplicates(path: Iterable[tuple[Reference, Reference, Reference]]) -> bool: """Return if the path has multiple unique.""" elements: set[Reference] = set() for u, v, _ in path: elements.add(u) elements.add(v) counter = Counter(element.prefix for element in elements) return any(v > 1 for v in counter.values())
[docs] def infer_mutual_dbxref_mutations( mappings: Iterable[Mapping], prefixes: Iterable[str], confidence: float | None = None, *, progress: bool = False, ) -> list[Mapping]: """Upgrade database cross-references into exact matches for the given pairs. :param mappings: A list of mappings :param prefixes: A dictionary of source/target prefix pairs to the confidence of upgrading dbxrefs. If giving a collection of pairs, will use the ``confidence`` value as given. :param confidence: The default confidence to be used if ``pairs`` is given as a collection. Defaults to 0.7 :param progress: Should a progress bar be shown? Defaults to true. :return: A new list of mappings containing upgrades In the following example, we use four different terms for *cranioectodermal dysplasia* from the Disease Ontology (DOID), Medical Subject Headings (MeSH), and Unified Medical Language System (UMLS). We use the prior knowledge that there's a high confidence that dbxrefs from DOID to MeSH are actually exact matches. This lets us infer ``m3`` from ``m1``. We don't make any assertions about DOID-UMLS or MeSH-UMLS mappings here, so the example mapping ``m2`` comes along for the ride. >>> from semra import DB_XREF, EXACT_MATCH, Reference, NARROW_MATCH >>> curies = "DOID:0050577", "mesh:C562966", "umls:C4551571" >>> r1, r2, r3 = map(Reference.from_curie, curies) >>> m1 = Mapping.from_triple((r1, DB_XREF, r2)) >>> m2 = Mapping.from_triple((r2, DB_XREF, r3)) >>> m3 = Mapping.from_triple( ... (r1, EXACT_MATCH, r2), ... evidence=[ ... ReasonedEvidence( ... mappings=[m1], justification=KNOWLEDGE_MAPPING, confidence_factor=0.99 ... ) ... ], ... ) # this is what we are inferring >>> assert infer_mutual_dbxref_mutations([m1, m2], ["DOID", "mesh"], confidence=0.99) == [ ... m1, ... m3, ... m2, ... ] This function is a thin wrapper around :func:`infer_mutations` where :data:`semra.DB_XREF` is used as the "old" predicated and :data:`semra.EXACT_MATCH` is used as the "new" predicate. """ prefixes = cleanup_prefixes(prefixes) pairs = { (subject_prefix, object_prefix) for subject_prefix, object_prefix in itt.product(prefixes, repeat=2) if subject_prefix != object_prefix } return infer_dbxref_mutations(mappings, pairs=pairs, confidence=confidence, progress=progress)
[docs] def infer_dbxref_mutations( mappings: Iterable[Mapping], pairs: dict[tuple[str, str], float] | Iterable[tuple[str, str]], confidence: float | None = None, progress: bool = False, ) -> list[Mapping]: """Upgrade database cross-references into exact matches for the given pairs. :param mappings: A list of mappings :param pairs: A dictionary of source/target prefix pairs to the confidence of upgrading dbxrefs. If giving a collection of pairs, will use the ``confidence`` value as given. :param confidence: The default confidence to be used if ``pairs`` is given as a collection. Defaults to 0.7 :param progress: Should a progress bar be shown? Defaults to true. :return: A new list of mappings containing upgrades In the following example, we use four different terms for *cranioectodermal dysplasia* from the Disease Ontology (DOID), Medical Subject Headings (MeSH), and Unified Medical Language System (UMLS). We use the prior knowledge that there's a high confidence that dbxrefs from DOID to MeSH are actually exact matches. This lets us infer ``m3`` from ``m1``. We don't make any assertions about DOID-UMLS or MeSH-UMLS mappings here, so the example mapping ``m2`` comes along for the ride. >>> from semra import DB_XREF, EXACT_MATCH, Reference, NARROW_MATCH >>> curies = "DOID:0050577", "mesh:C562966", "umls:C4551571" >>> r1, r2, r3 = (Reference.from_curie(c) for c in curies) >>> m1 = Mapping.from_triple((r1, DB_XREF, r2)) >>> m2 = Mapping.from_triple((r2, DB_XREF, r3)) >>> mappings = [m1, m2] >>> pairs = {("DOID", "mesh"): 0.99} >>> m3 = Mapping.from_triple( ... (r1, EXACT_MATCH, r2), ... evidence=[ ... ReasonedEvidence( ... mappings=[m1], justification=KNOWLEDGE_MAPPING, confidence_factor=0.99 ... ) ... ], ... ) # this is what we are inferring >>> assert infer_dbxref_mutations(mappings, pairs) == [m1, m3, m2] This function is a thin wrapper around :func:`infer_mutations` where :data:`semra.DB_XREF` is used as the "old" predicated and :data:`semra.EXACT_MATCH` is used as the "new" predicate. """ if confidence is None: confidence = 0.7 if not isinstance(pairs, dict): pairs = dict.fromkeys(pairs, confidence) return infer_mutations( mappings, pairs=pairs, old_predicate=DB_XREF, new_predicate=EXACT_MATCH, progress=progress, )
[docs] def infer_mutations( mappings: Iterable[Mapping], pairs: dict[tuple[str, str], float], old_predicate: Reference, new_predicate: Reference, *, progress: bool = False, ) -> list[Mapping]: """Infer mappings with alternate predicates for the given prefix pairs. :param mappings: Mappings to infer from :param pairs: A dictionary of pairs of (subject prefix, object prefix) to the confidence of inference :param old_predicate: The predicate on which inference should be done :param new_predicate: The predicate to get inferred :param progress: Should a progress bar be shown? Defaults to true. :returns: A list of all old mapping plus inferred ones interspersed. In the following example, we use three different terms for *cranioectodermal dysplasia* from the Disease Ontology (DOID), Medical Subject Headings (MeSH), and Unified Medical Language System (UMLS). We use the prior knowledge that there's a high confidence that dbxrefs from DOID to MeSH are actually exact matches. This lets us infer ``m3`` from ``m1``. We don't make any assertions about DOID-UMLS or MeSH-UMLS mappings here, so the example mapping ``m2`` comes along for the ride. >>> from semra.vocabulary import KNOWLEDGE_MAPPING >>> from semra import DB_XREF, EXACT_MATCH, Reference >>> curies = "DOID:0050577", "mesh:C562966", "umls:C4551571" >>> r1, r2, r3 = (Reference.from_curie(c) for c in curies) >>> m1 = Mapping.from_triple((r1, DB_XREF, r2)) >>> m2 = Mapping.from_triple((r2, DB_XREF, r3)) >>> pairs = {("DOID", "mesh"): 0.99} >>> m3 = Mapping.from_triple( ... (r1, EXACT_MATCH, r2), ... evidence=[ ... ReasonedEvidence( ... mappings=[m1], justification=KNOWLEDGE_MAPPING, confidence_factor=0.99 ... ) ... ], ... ) # this is what we are inferring # this is what we are inferring >>> mappings = infer_mutations([m1, m2], pairs, DB_XREF, EXACT_MATCH) >>> assert mappings == [m1, m3, m2] """ configurations = [ Configuration( old=old_predicate, new=new_predicate, pairs=_clean_pairs(pairs), ) ] return _mutate(mappings, configurations, progress=progress)
class Configuration(BaseModel): """A configuration for mutation.""" old: Reference new: Reference default_confidence: float | None = None pairs: dict[tuple[str, str], float] | None = None def _mutate( mappings: Iterable[Mapping], configurations: list[Configuration], *, progress: bool = False, ) -> list[Mapping]: rv = [] # index all configurations upgrade_map = {c.old: c for c in configurations} for mapping in semra_tqdm(mappings, desc="Adding mutated predicates", progress=progress): rv.append(mapping) configuration = upgrade_map.get(mapping.predicate) if configuration is None: continue confidence_factor: float | None if configuration.default_confidence: confidence_factor = configuration.default_confidence elif configuration.pairs: confidence_factor = configuration.pairs.get( (mapping.subject.prefix, mapping.object.prefix) ) else: raise ValueError if confidence_factor is None: # This means that there was no explicit confidence set for the # subject/object prefix pair, meaning it wasn't asked to be inferred continue inferred_mapping = Mapping( subject=mapping.subject, predicate=configuration.new, object=mapping.object, evidence=[ ReasonedEvidence( justification=KNOWLEDGE_MAPPING, mappings=[mapping], confidence_factor=confidence_factor, ) ], ) rv.append(inferred_mapping) return rv def _clean_pairs(pairs: dict[tuple[str, str], float]) -> dict[tuple[str, str], float]: rv = {} for (p1, p2), v in pairs.items(): p1_norm = bioregistry.normalize_prefix(p1, strict=True) p2_norm = bioregistry.normalize_prefix(p2, strict=True) rv[p1_norm, p2_norm] = v return rv
[docs] def infer_generalizations( mappings: list[Mapping], *, progress: bool = False, ) -> list[Mapping]: """Apply generalization rules. :param mappings: Mappings to process :param progress: Should a progress bar be used? :returns: Mappings that have been mutated to relax relations configured by :data:`semra.rules.GENERALIZATIONS` .. seealso:: Rules definition in SSSOM https://mapping-commons.github.io/sssom/chaining-rules/#generalisation-rules """ configurations = [ Configuration(old=old, new=new, default_confidence=1.0) for old, new in GENERALIZATIONS.items() ] return _mutate(mappings, configurations, progress=progress)