"""Pseudotime inference."""
import numpy as np
import networkx as nx
[docs]
def infer_pseudotime(
    adata,
    source,
    target=None,
    nodes_to_include=None,
    key="epg",
    copy=False,
):
    """Infer pseudotime along the graph stored in ``adata.uns[key]``.

    A weighted undirected graph is built from ``adata.uns[key]["edge"]``
    (two columns of node ids) and ``adata.uns[key]["edge_len"]`` (one
    length per edge). Each selected cell receives the graph distance from
    `source` to the start of the edge it is assigned to, plus its
    position along that edge.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix. Must provide ``adata.uns[key]`` with the
        entries ``"edge"`` and ``"edge_len"``, and the ``adata.obs``
        columns ``f"{key}_node_id"``, ``f"{key}_edge_id"`` and
        ``f"{key}_edge_loc"``.
    source:
        Node at which pseudotime is zero.
    target: optional (default: None)
        If given, pseudotime is measured along one path from `source`
        to `target`. If ``None``, every node reachable from `source` is
        used and distances are weighted shortest-path lengths.
    nodes_to_include: `list`, optional (default: None)
        Only used when `target` is given. Nodes the path must pass
        through; the shortest simple path containing all of them is used.
    key: `str`, optional (default: 'epg')
        Prefix of the ``adata.uns`` / ``adata.obs`` entries to read and
        write.
    copy: `bool`, optional (default: False)
        If ``True``, return the pseudotime values instead of writing to
        adata.

    Returns
    -------
    ``numpy.ndarray`` if ``copy=True``: pseudotime of the selected cells
    only (those whose ``f"{key}_node_id"`` is on the traversed node set),
    in ``adata.obs`` order.

    ``str`` if `nodes_to_include` is given and no simple path from
    `source` to `target` contains all of those nodes; nothing is written
    to `adata` in that case.

    ``None`` otherwise; updates `adata` with the following fields:

    ``adata.obs[f"{key}_pseudotime"]``
        Pseudotime per cell, ``NaN`` for cells that were not selected.
    ``adata.uns[f"{key}_pseudotime_params"]``
        The `source`, `target` and `nodes_to_include` that were used.

    Raises
    ------
    AssertionError
        If `target` is given and `nodes_to_include` is neither ``None``
        nor a ``list``.
    """
    epg_edge = adata.uns[key]["edge"]
    epg_edge_len = adata.uns[key]["edge_len"]

    G = nx.Graph()
    G.add_weighted_edges_from(
        zip(epg_edge[:, 0], epg_edge[:, 1], epg_edge_len), weight="len"
    )

    if target is None:
        # Every node reachable from the source, in breadth-first order.
        nodes_sp = [source] + [v for _, v in nx.bfs_edges(G, source)]
        # Only the source's connected component is reachable, so querying
        # G directly gives the same distances as an induced subgraph would.
        dict_dist_to_source = nx.shortest_path_length(
            G, source=source, weight="len"
        )
    else:
        if nodes_to_include is None:
            nodes_sp = nx.shortest_path(
                G, source=source, target=target, weight="len"
            )
        else:
            # Explicit raise (not `assert`) so the check survives `-O`;
            # the exception type and message are kept for existing callers.
            if not isinstance(nodes_to_include, list):
                raise AssertionError("`nodes_to_include` must be list")
            required = set(nodes_to_include)
            # Paths are generated from shortest to longest; consume the
            # generator lazily and stop at the first one that qualifies
            # instead of enumerating every simple path up front.
            nodes_sp = next(
                (
                    path
                    for path in nx.shortest_simple_paths(
                        G, source=source, target=target, weight="len"
                    )
                    if required.issubset(path)
                ),
                None,
            )
            if nodes_sp is None:
                return f"no path that passes {nodes_to_include} exists"
        # Cumulative length along the chosen path, starting at 0 for source.
        step_len = [
            G.get_edge_data(u, v)["len"]
            for u, v in zip(nodes_sp, nodes_sp[1:])
        ]
        dict_dist_to_source = dict(
            zip(nodes_sp, np.cumsum(np.array([0.0] + step_len)))
        )

    # Cells whose assigned node lies on the traversed node set.
    cells = np.isin(adata.obs[f"{key}_node_id"], nodes_sp)
    id_edges_cell = adata.obs.loc[cells, f"{key}_edge_id"].tolist()
    edges_cell = epg_edge[id_edges_cell, :]
    len_edges_cell = epg_edge_len[id_edges_cell]
    # Proportion along the edge, clipped to [0, 1]. Clipping the NumPy
    # array yields a fresh array, independent of the DataFrame's buffer.
    prop_edge = np.clip(
        adata.obs.loc[cells, f"{key}_edge_loc"].to_numpy(), 0, 1
    )

    # Position of each node in the traversal order. Nodes outside the
    # traversal rank after all others. (Traversal nodes are unique, so
    # this matches a first-occurrence lookup.)
    n_nodes = G.number_of_nodes()
    node_rank = {node: i for i, node in enumerate(nodes_sp)}
    rank_first = np.array(
        [node_rank.get(node, n_nodes) for node in edges_cell[:, 0]]
    )
    rank_second = np.array(
        [node_rank.get(node, n_nodes) for node in edges_cell[:, 1]]
    )

    # An edge is "flipped" when its second endpoint comes earlier in the
    # traversal: distance is then measured from that endpoint and the
    # position along the edge is mirrored.
    flipped = rank_first > rank_second
    edge_start = np.where(flipped, edges_cell[:, 1], edges_cell[:, 0])
    dist_to_source = np.array(
        [dict_dist_to_source[node] for node in edge_start]
    )
    prop_edge = np.where(flipped, 1 - prop_edge, prop_edge)

    dist = dist_to_source + len_edges_cell * prop_edge

    if copy:
        return dist
    adata.obs[f"{key}_pseudotime"] = np.nan
    adata.obs.loc[cells, f"{key}_pseudotime"] = dist
    adata.uns[f"{key}_pseudotime_params"] = {
        "source": source,
        "target": target,
        "nodes_to_include": nodes_to_include,
    }