Kth Ancestor of a Tree Node

You are given a tree with n nodes numbered from 0 to n-1, in the form of a parent array where parent[i] is the parent of node i. The root of the tree is node 0.

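Implement the TreeAncestor class. The constructor TreeAncestor(int n, int[] parent) receives the number of nodes and the parent array, and the method getKthAncestor(int node, int k) returns the k-th ancestor of the given node. If there is no such ancestor, return -1.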

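The k-th ancestor of a tree node is the k-th node on the path from that node to the root, not counting the node itself. For example, the 1st ancestor is the parent and the 2nd ancestor is the grandparent.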
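Read literally, the definition suggests walking up the parent array one step at a time. The sketch below does exactly that; the class name NaiveAncestor and the helper kthAncestorNaive are hypothetical, chosen only for illustration. Each query costs O(k), which is simple but can be slow when many queries ask for distant ancestors on a long chain.

```java
// Hypothetical illustration class; not part of the original solution.
class NaiveAncestor {
    // Walks up the parent array one step at a time: O(k) per query.
    static int kthAncestorNaive(int[] parent, int node, int k) {
        while (k > 0 && node != -1) {
            node = parent[node]; // parent[0] == -1, so stepping past the root yields -1
            k--;
        }
        return node;
    }

    public static void main(String[] args) {
        int[] parent = {-1, 0, 0, 1, 1, 2, 2};
        System.out.println(kthAncestorNaive(parent, 3, 1)); // 1
        System.out.println(kthAncestorNaive(parent, 5, 2)); // 0
        System.out.println(kthAncestorNaive(parent, 6, 3)); // -1
    }
}
```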

Example:

Input:
["TreeAncestor","getKthAncestor","getKthAncestor","getKthAncestor"]
[[7,[-1,0,0,1,1,2,2]],[3,1],[5,2],[6,3]]

Output:
[null,1,0,-1]

Explanation:
TreeAncestor treeAncestor = new TreeAncestor(7, [-1, 0, 0, 1, 1, 2, 2]);

treeAncestor.getKthAncestor(3, 1);  // returns 1 which is the parent of 3
treeAncestor.getKthAncestor(5, 2);  // returns 0 which is the grandparent of 5
treeAncestor.getKthAncestor(6, 3);  // returns -1 because there is no such ancestor

Constraints:

  • 1 <= k <= n <= 5*10^4

  • parent[0] == -1, indicating that 0 is the root node.

  • 0 <= parent[i] < n for all 0 < i < n

  • 0 <= node < n

  • There will be at most 5*10^4 queries.
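A rough count, assuming the worst case of a single chain of 5*10^4 nodes queried with large k, shows why one-step walking is not enough and why a table of power-of-two jumps (binary lifting, the approach the solution below uses) fits comfortably:

```latex
\begin{aligned}
\text{one step at a time: } & 5\times10^4 \text{ queries} \times 5\times10^4 \text{ steps} = 2.5\times10^9 \\
\text{power-of-two jumps: } & k \le 5\times10^4 < 2^{16} \;\Rightarrow\; \text{at most } 16 \text{ jumps per query} \\
& 5\times10^4 \times 16 = 8\times10^5 \text{ jumps in total}
\end{aligned}
```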

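import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TreeAncestor {
    // Map of parent -> list of its children
    private final Map<Integer, List<Integer>> tree = new HashMap<>();
    // dp[u][i] -> ancestor at a distance of 2^i from u, or null if there is none
    private final Integer[][] dp;

    public TreeAncestor(int n, int[] parent) {
        // Enough columns for every power of two that can appear in k (k <= n)
        int logN = (int) Math.ceil(Math.log(n) / Math.log(2)) + 1;
        dp = new Integer[n][logN];
        for (int i = 1; i < n; i++) {
            // Every node except the root (node 0) has a parent
            tree.computeIfAbsent(parent[i], key -> new ArrayList<>()).add(i);
            dp[i][0] = parent[i]; // 2^0 = 1 => direct parent
        }
        // Node 0 is the root
        fill(0);
    }

    // O(N log N). Visits nodes top-down (BFS from the root), so every ancestor's
    // row is complete before a descendant reads it. An explicit queue replaces
    // recursion, which can overflow the stack on a chain of 5*10^4 nodes.
    private void fill(int root) {
        Deque<Integer> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            int cur = queue.poll();
            // Recurrence: 2^i = 2^(i-1) + 2^(i-1)
            for (int i = 1; i < dp[cur].length && dp[cur][i - 1] != null; i++) {
                int nodeAtHalfDistance = dp[cur][i - 1];
                dp[cur][i] = dp[nodeAtHalfDistance][i - 1];
            }
            queue.addAll(tree.getOrDefault(cur, Collections.emptyList()));
        }
    }

    public int getKthAncestor(int node, int k) {
        int currPower = 0;
        // Like fast exponentiation: decompose k into powers of two
        // and jump once for every set bit
        while (k > 0) {
            // The lowest remaining bit of k is set
            if (k % 2 == 1) {
                if (dp[node][currPower] == null)
                    return -1;
                node = dp[node][currPower];
            }
            currPower++;
            k = k / 2;
        }
        return node;
    }
}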
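For comparison, the same table can be filled without a child map or any traversal by computing it one column at a time: column i-1 is complete for every node before column i is built, so the order of entries in parent[] does not matter. The sketch below uses the hypothetical class name TreeAncestorIterative, stores -1 instead of null for a missing ancestor, and sizes the table by the bit length of n, since every valid k <= n fits in that many bits. It is an alternative formulation of the same binary-lifting idea, not the author's code.

```java
// Hypothetical alternative to the solution above; a sketch, not the original code.
class TreeAncestorIterative {
    private final int[][] up; // up[v][j] = ancestor 2^j steps above v, or -1 if none

    TreeAncestorIterative(int n, int[] parent) {
        int log = 32 - Integer.numberOfLeadingZeros(n); // bit length of n; every k <= n fits
        up = new int[n][log];
        for (int v = 0; v < n; v++) {
            up[v][0] = parent[v]; // parent[0] == -1 marks the root
        }
        // Column j depends only on column j-1, which is already complete for all nodes
        for (int j = 1; j < log; j++) {
            for (int v = 0; v < n; v++) {
                int mid = up[v][j - 1];
                up[v][j] = (mid == -1) ? -1 : up[mid][j - 1];
            }
        }
    }

    int getKthAncestor(int node, int k) {
        // Jump 2^j for every set bit j of k; stop early once we fall off the root
        for (int j = 0; k > 0 && node != -1; j++, k >>= 1) {
            if ((k & 1) == 1) {
                node = up[node][j];
            }
        }
        return node;
    }

    public static void main(String[] args) {
        TreeAncestorIterative t = new TreeAncestorIterative(7, new int[]{-1, 0, 0, 1, 1, 2, 2});
        System.out.println(t.getKthAncestor(3, 1)); // 1
        System.out.println(t.getKthAncestor(5, 2)); // 0
        System.out.println(t.getKthAncestor(6, 3)); // -1
    }
}
```

Build time and space stay O(n log n) and each query takes O(log k) jumps; the trade-off is a dense n-by-log table in which shallow nodes carry rows of -1.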
