"""Second order centrality measure."""

# Authors: Erwan Le Merrer (erwan.lemerrer@technicolor.com)

import networkx as nx
from networkx.utils import not_implemented_for

__all__ = ["second_order_centrality"]


@not_implemented_for("directed")
def second_order_centrality(G, weight="weight"):
    """Compute the second order centrality for nodes of G.

    The second order centrality of a given node is the standard deviation of
    the return times to that node of a perpetual random walk on G.

    Parameters
    ----------
    G : graph
      A NetworkX connected and undirected graph.

    weight : string or None, optional (default="weight")
      Name of the edge attribute holding the edge weight. Edges without this
      attribute are given weight 1. If None, every edge has weight 1.

    Returns
    -------
    nodes : dictionary
       Dictionary keyed by node with second order centrality as the value.

    Raises
    ------
    NetworkXException
        If the graph G is empty, disconnected or has negative weights.

    See Also
    --------
    betweenness_centrality

    Notes
    -----
    Lower values of second order centrality indicate higher centrality.

    The algorithm is from Kermarrec, Le Merrer, Sericola and Trédan [1]_.

    This code implements the analytical version of the algorithm, i.e., there
    is no simulation of a random walk process involved. The random walk is here
    unbiased (corresponding to eq 6 of the paper [1]_), thus the centrality
    values are the standard deviations for random walk return times on the
    transformed input graph G (equal weighted degree at each node, obtained by
    adding self-loop weight). Self-loops already present in G are kept: the
    balancing weight is added to them rather than replacing them. Parallel
    edges of a multigraph have their weights summed.

    Complexity of this implementation, made to run locally on a single machine,
    is O(n^3), with n the size of G, which makes it viable only for small
    graphs.

    References
    ----------
    .. [1] Anne-Marie Kermarrec, Erwan Le Merrer, Bruno Sericola, Gilles Trédan
       "Second order centrality: Distributed assessment of nodes criticity in
       complex networks", Elsevier Computer Communications 34(5):619-628, 2011.

    Examples
    --------
    >>> G = nx.star_graph(10)
    >>> soc = nx.second_order_centrality(G)
    >>> print(sorted(soc.items(), key=lambda x: x[1])[0][0])  # pick first id
    0
    """
    import numpy as np

    n = len(G)

    if n == 0:
        raise nx.NetworkXException("Empty graph.")
    if not nx.is_connected(G):
        raise nx.NetworkXException("Non connected graph.")
    # With weight=None every edge has weight 1, so nothing can be negative.
    if weight is not None and any(
        d.get(weight, 0) < 0 for _, _, d in G.edges(data=True)
    ):
        raise nx.NetworkXException("Graph has negative edge weights.")

    nodes = list(G)
    if n == 1:
        # A walk on a single node returns to it at every step, so the return
        # time is constant and its standard deviation is zero.
        return {nodes[0]: 0.0}

    # Symmetric weight matrix with rows/columns in `nodes` order. Missing
    # weight attributes count as 1; a self-loop contributes its weight once
    # on the diagonal.
    A = nx.to_numpy_array(G, nodelist=nodes, weight=weight)

    # Balance G for Metropolis-Hastings random walks: raise every node's
    # weighted degree to the maximum one by adding self-loop weight. Using
    # `+=` on the diagonal keeps any self-loop G already has (overwriting it
    # would leave that node unbalanced and break the uniform stationary
    # distribution that eq 6 relies on).
    deg = A.sum(axis=1)
    A[np.diag_indices(n)] += deg.max() - deg

    P = A / A.sum(axis=1, keepdims=True)  # transition probability matrix

    # Column j of M solves (I - Q_j) m = 1, where Q_j is P with column j
    # zeroed (the walk stops once it enters j). M[i, j] is thus the expected
    # number of steps to reach j starting from i; M[j, j] is the mean return
    # time to j.  # eq 3
    eye = np.identity(n)
    ones = np.ones(n)
    M = np.empty((n, n))
    for j in range(n):
        Q = P.copy()
        Q[:, j] = 0.0
        M[:, j] = np.linalg.solve(eye - Q, ones)

    # eq 6: variance of the return time to j is 2 * sum_i M[i, j] - n(n + 1).
    # It is mathematically non-negative; clip round-off so sqrt never gives nan.
    var = 2.0 * M.sum(axis=0) - n * (n + 1)
    sigma = np.sqrt(np.maximum(var, 0.0))
    return {v: float(s) for v, s in zip(nodes, sigma)}