Linked List in Binary Tree
Given the root of a binary tree and a linked list whose first node is head, return true if all the elements of the linked list, starting from head, correspond to some connected downward path in the binary tree; otherwise return false.
In this context, a downward path is a path that starts at some node and goes downwards, moving from each node to one of its children.
Example 1:

Input: head = [4,2,8], root = [1,4,4,null,2,2,null,1,null,6,8,null,null,null,null,1,3]
Output: true
Explanation: The nodes 4 -> 2 -> 8 (shown in blue in the problem's figure) form a downward subpath in the binary tree.
Example 2:

Input: head = [1,4,2,6], root = [1,4,4,null,2,2,null,1,null,6,8,null,null,null,null,1,3]
Output: true
Example 3:
Input: head = [1,4,2,6,8], root = [1,4,4,null,2,2,null,1,null,6,8,null,null,null,null,1,3]
Output: false
Explanation: There is no path in the binary tree that contains all the elements of the linked list from head.
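To try the examples locally, the level-order notation used above (null marks a missing child) has to be turned into real nodes. The sketch below assumes minimal ListNode/TreeNode classes shaped like LeetCode's; ExampleBuilder and its list/tree/isSubPath methods are hypothetical helpers, and the check itself is the same brute-force idea as the solution further down.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Minimal node types, assumed to mirror LeetCode's ListNode/TreeNode.
class ListNode {
    int val;
    ListNode next;
    ListNode(int val, ListNode next) { this.val = val; this.next = next; }
}

class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

// Hypothetical helper for building and checking the examples.
class ExampleBuilder {
    // Builds a linked list from values, e.g. list(4, 2, 8) gives 4 -> 2 -> 8.
    static ListNode list(int... vals) {
        ListNode head = null;
        for (int i = vals.length - 1; i >= 0; i--)
            head = new ListNode(vals[i], head);
        return head;
    }

    // Builds a tree from level order; null means "no child in this slot".
    static TreeNode tree(Integer... vals) {
        if (vals.length == 0 || vals[0] == null)
            return null;
        TreeNode root = new TreeNode(vals[0]);
        Deque<TreeNode> queue = new ArrayDeque<>();  // parents waiting for children
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < vals.length) {
            TreeNode parent = queue.poll();
            if (vals[i] != null) {
                parent.left = new TreeNode(vals[i]);
                queue.add(parent.left);
            }
            i++;
            if (i < vals.length && vals[i] != null) {
                parent.right = new TreeNode(vals[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }

    // Brute force: try every tree node as the start of the downward path.
    static boolean isSubPath(ListNode head, TreeNode root) {
        if (head == null) return true;
        if (root == null) return false;
        return matchFrom(head, root) || isSubPath(head, root.left) || isSubPath(head, root.right);
    }

    // True if the list from head matches a downward path starting exactly at node.
    private static boolean matchFrom(ListNode head, TreeNode node) {
        if (head == null) return true;
        if (node == null) return false;
        return head.val == node.val
                && (matchFrom(head.next, node.left) || matchFrom(head.next, node.right));
    }
}
```

For example, ExampleBuilder.isSubPath(ExampleBuilder.list(4, 2, 8), ExampleBuilder.tree(1, 4, 4, null, 2, 2, null, 1, null, 6, 8, null, null, null, null, 1, 3)) returns true, matching Example 1.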
Constraints:
1 <= node.val <= 100 for each node in the linked list and binary tree.
The given linked list will contain between 1 and 100 nodes.
The given binary tree will contain between 1 and 2500 nodes.
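These limits make the brute-force approach cheap enough. With N tree nodes and M list nodes, a matching DFS started at node s only inspects nodes at most M - 1 levels below s (it stops once the list is exhausted), so each tree node v is inspected only by DFS runs started at v itself or at one of its ancestors within M - 1 levels:

```latex
\text{dfs inspections} \;\le\; \sum_{v} \#\{\, s : s \text{ is } v \text{ or an ancestor of } v \text{ at most } M-1 \text{ levels up} \,\}
\;\le\; N \cdot M \;\le\; 2500 \cdot 100 = 250{,}000
```

The outer traversal adds one visit per tree node, so the total stays O(N * M).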
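import java.util.ArrayList;
import java.util.List;

// ListNode and TreeNode are the definitions supplied by LeetCode.
class Solution {
    // Approach 1: brute force, O(N * M) for N tree nodes and M list nodes.
    // Try every tree node as a starting point and match the list downward from it.
    public boolean isSubPath(ListNode head, TreeNode root) {
        if (head == null)
            return true;
        if (root == null)
            return false;
        return dfs(head, root) || isSubPath(head, root.left) || isSubPath(head, root.right);
    }

    // True if the list starting at head matches a downward path starting at root.
    private boolean dfs(ListNode head, TreeNode root) {
        if (head == null)
            return true;   // every list node has been matched
        if (root == null)
            return false;  // the tree path ended before the list did
        return head.val == root.val && (dfs(head.next, root.left) || dfs(head.next, root.right));
    }

    // Approach 2: KMP. Every root-to-node path is scanned as the text, with the
    // list values as the pattern. Java does not allow two isSubPath methods with
    // the same signature in one class, so this entry point is named isSubPathKmp
    // here; rename it to isSubPath (and drop approach 1) to submit it on its own.
    int[] needle, lps;

    public boolean isSubPathKmp(ListNode head, TreeNode root) {
        needle = convertLinkedListToArray(head);
        lps = computeKMPTable(needle);
        return kmpSearch(root, 0);
    }

    // i is the current tree node; j is how many needle values are already matched
    // at the end of the path from the root down to i's parent.
    boolean kmpSearch(TreeNode i, int j) {
        if (j == needle.length)
            return true;
        if (i == null)
            return false;
        while (j > 0 && i.val != needle[j])
            j = lps[j - 1];  // fall back to the next shorter matched prefix
        if (i.val == needle[j])
            j++;
        // Both children continue from the same matched length j.
        return kmpSearch(i.left, j) || kmpSearch(i.right, j);
    }

    // lps[i] = length of the longest proper prefix of pattern[0..i]
    // that is also a suffix of pattern[0..i].
    int[] computeKMPTable(int[] pattern) {
        int n = pattern.length;
        int[] lps = new int[n];
        for (int i = 1, j = 0; i < n; i++) {
            while (j > 0 && pattern[i] != pattern[j])
                j = lps[j - 1];
            if (pattern[i] == pattern[j])
                lps[i] = ++j;
        }
        return lps;
    }

    int[] convertLinkedListToArray(ListNode head) {
        List<Integer> list = new ArrayList<>();
        while (head != null) {
            list.add(head.val);
            head = head.next;
        }
        int[] arr = new int[list.size()];
        for (int i = 0; i < list.size(); i++)
            arr[i] = list.get(i);
        return arr;
    }
}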
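To see what computeKMPTable produces and how the matched length j is reused, here is a standalone sketch on plain int arrays. The class KmpDemo and its contains method are illustrative only: contains is the single-path analogue of kmpSearch, where the tree version hands the same j to both children instead of to the next array element.

```java
// Illustrative class, not part of the solution above.
class KmpDemo {
    // lps[i] = length of the longest proper prefix of pattern[0..i]
    // that is also a suffix of pattern[0..i].
    static int[] computeKMPTable(int[] pattern) {
        int[] lps = new int[pattern.length];
        for (int i = 1, j = 0; i < pattern.length; i++) {
            while (j > 0 && pattern[i] != pattern[j])
                j = lps[j - 1];   // try the next shorter border
            if (pattern[i] == pattern[j])
                lps[i] = ++j;     // extend the current border by one value
        }
        return lps;
    }

    // Returns true if pattern occurs contiguously in text.
    // j counts how many pattern values are matched so far.
    static boolean contains(int[] text, int[] pattern) {
        if (pattern.length == 0)
            return true;
        int[] lps = computeKMPTable(pattern);
        int j = 0;
        for (int value : text) {
            while (j > 0 && value != pattern[j])
                j = lps[j - 1];   // keep the longest still-valid matched prefix
            if (value == pattern[j])
                j++;
            if (j == pattern.length)
                return true;
        }
        return false;
    }
}
```

For the pattern [1, 2, 1, 2, 3], computeKMPTable returns [0, 0, 1, 2, 0]. Searching for [1, 1, 2] along the path 1 -> 1 -> 1 -> 2, the mismatch at the third 1 drops j from 2 to lps[1] = 1 rather than to 0, so matching resumes without re-reading earlier nodes.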