You are given a tree with n nodes numbered from 0 to n-1 in the form of a parent array where parent[i] is the parent of node i. The root of the tree is node 0.
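Implement the TreeAncestor class. The constructor TreeAncestor(int n, int[] parent) initializes the object with the tree, and getKthAncestor(int node, int k) returns the k-th ancestor of the given node. If there is no such ancestor, it returns -1.
The k-th ancestor of a tree node is the k-th node on the path from that node to the root, not counting the node itself; the 1st ancestor is the parent.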
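A direct reading of this definition, as a baseline sketch: follow parent pointers k times and stop early if the walk runs past the root. The class and method names (NaiveKthAncestor, kthAncestor) are hypothetical and not part of the required API. Each query costs O(k) steps.

```java
class NaiveKthAncestor {
    // Walks parent pointers one step at a time: O(k) per query.
    static int kthAncestor(int[] parent, int node, int k) {
        for (int step = 0; step < k && node != -1; step++) {
            node = parent[node]; // parent[0] == -1 marks the root
        }
        return node; // -1 if we walked past the root
    }

    public static void main(String[] args) {
        int[] parent = {-1, 0, 0, 1, 1, 2, 2};
        System.out.println(kthAncestor(parent, 3, 1)); // 1
        System.out.println(kthAncestor(parent, 5, 2)); // 0
        System.out.println(kthAncestor(parent, 6, 3)); // -1
    }
}
```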
Example:
Input:
["TreeAncestor","getKthAncestor","getKthAncestor","getKthAncestor"]
[[7,[-1,0,0,1,1,2,2]],[3,1],[5,2],[6,3]]
Output:
[null,1,0,-1]
Explanation:
TreeAncestor treeAncestor = new TreeAncestor(7, [-1, 0, 0, 1, 1, 2, 2]);
treeAncestor.getKthAncestor(3, 1); // returns 1 which is the parent of 3
treeAncestor.getKthAncestor(5, 2); // returns 0 which is the grandparent of 5
treeAncestor.getKthAncestor(6, 3); // returns -1 because there is no such ancestor
Constraints:
1 <= k <= n <= 5*10^4
parent[0] == -1 indicating that 0 is the root node.
0 <= parent[i] < n for all 0 < i < n
0 <= node < n
There will be at most 5*10^4 queries.
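A rough count, assuming the worst case of a path-shaped tree where every query asks for a large k, of why walking parent pointers one at a time is too slow under these constraints, and what a table of power-of-two jumps with ceil(log2 n) + 1 columns per node costs instead:

```latex
\begin{aligned}
\text{one step at a time: } & Q \cdot k \le 5\cdot10^4 \cdot 5\cdot10^4 = 2.5\cdot10^9 \text{ parent steps} \\
\text{jump table width: } & \lceil \log_2 (5\cdot10^4) \rceil + 1 = 16 + 1 = 17 \text{ columns} \\
\text{jump table size: } & 5\cdot10^4 \cdot 17 = 8.5\cdot10^5 \text{ entries} \\
\text{per query: } & k \le 5\cdot10^4 < 2^{16} \;\Rightarrow\; \le 16 \text{ jumps}
\end{aligned}
```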
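Solution (binary lifting): dp[u][i] holds the ancestor 2^i levels above u, so a query jumps along the set bits of k in O(log k) steps.

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TreeAncestor {
    // Map of parent -> list of its children
    private final Map<Integer, List<Integer>> tree = new HashMap<>();
    // dp[u][i] -> ancestor at a distance of 2^i from u, or null if there is none
    private final Integer[][] dp;

    public TreeAncestor(int n, int[] parent) {
        // Columns for jumps 2^0 .. 2^ceil(log2 n); enough for every k <= n
        int logN = (int) Math.ceil(Math.log(n) / Math.log(2)) + 1;
        dp = new Integer[n][logN];
        // Every node except the root (node 0) has a parent
        for (int i = 1; i < n; i++) {
            tree.computeIfAbsent(parent[i], key -> new ArrayList<>()).add(i);
            dp[i][0] = parent[i]; // 2^0 = 1 => direct parent
        }
        // Node 0 is the root
        fillTable(0);
    }

    // O(N log N). Fills rows top-down so that every ancestor's row is complete
    // before a descendant reads it. An explicit stack replaces recursion, which
    // may overflow the call stack on a path-shaped tree of 5*10^4 nodes.
    private void fillTable(int root) {
        Deque<Integer> stack = new ArrayDeque<>();
        stack.push(root);
        while (!stack.isEmpty()) {
            int cur = stack.pop();
            // Recurrence relation: 2^i = 2^(i-1) + 2^(i-1)
            for (int i = 1; i < dp[cur].length && dp[cur][i - 1] != null; i++) {
                int nodeAtHalfDistance = dp[cur][i - 1];
                dp[cur][i] = dp[nodeAtHalfDistance][i - 1];
            }
            for (int child : tree.getOrDefault(cur, List.of()))
                stack.push(child);
        }
    }

    public int getKthAncestor(int node, int k) {
        int currPower = 0;
        // Walk the set bits of k, like fast exponentiation:
        // k = 5 (binary 101) is a jump of 1 followed by a jump of 4
        while (k > 0) {
            // Lowest remaining bit of k is set: jump 2^currPower levels up
            if (k % 2 == 1) {
                if (currPower >= dp[node].length || dp[node][currPower] == null)
                    return -1;
                node = dp[node][currPower];
            }
            currPower++;
            k /= 2;
        }
        return node;
    }
}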
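The constructor above must visit parents before children, because each node's row reads its ancestors' rows. A hedged alternative sketch, not the author's code: fill the same binary-lifting table one power of two at a time (outer loop over i), so column i-1 is complete for every node before column i is read, and no children map or traversal order is needed. The class name TreeAncestorByLevels, the field up, and the -1 sentinel in place of null are illustrative choices. The query is the same bit-by-bit decomposition of k.

```java
import java.util.Arrays;

// Hypothetical variant: same binary-lifting table, filled column by column.
class TreeAncestorByLevels {
    private final int[][] up; // up[i][v] = ancestor 2^i levels above v, or -1

    TreeAncestorByLevels(int n, int[] parent) {
        int levels = 1;
        while ((1 << levels) < n) levels++; // 2^levels >= n > any depth
        up = new int[levels][];
        up[0] = Arrays.copyOf(parent, n); // 2^0 = 1 => direct parent
        for (int i = 1; i < levels; i++) {
            up[i] = new int[n];
            for (int v = 0; v < n; v++) {
                int mid = up[i - 1][v];
                // 2^i = 2^(i-1) + 2^(i-1); column i-1 is already complete
                up[i][v] = (mid == -1) ? -1 : up[i - 1][mid];
            }
        }
    }

    int getKthAncestor(int node, int k) {
        for (int i = 0; k > 0 && node != -1; i++, k >>= 1) {
            if ((k & 1) == 1) {
                // A jump of 2^i with i >= levels exceeds every depth
                node = (i < up.length) ? up[i][node] : -1;
            }
        }
        return node;
    }

    public static void main(String[] args) {
        TreeAncestorByLevels t =
            new TreeAncestorByLevels(7, new int[] {-1, 0, 0, 1, 1, 2, 2});
        System.out.println(t.getKthAncestor(3, 1)); // 1
        System.out.println(t.getKthAncestor(5, 2)); // 0
        System.out.println(t.getKthAncestor(6, 3)); // -1
    }
}
```

Filling by columns trades the children map and the traversal for a plain double loop over the same O(N log N) entries.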