# Find the Kth Smallest Sum of a Matrix With Sorted Rows

You are given an `m x n` matrix `mat` whose rows are sorted in non-decreasing order, and an integer `k`.

You must choose exactly one element from each row to form an array. Return the `k`th **smallest** array sum among all possible arrays.
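To make the definition concrete, here is a minimal brute-force sketch. The class `BruteForce` and its methods are hypothetical names, not part of the original solution. It enumerates all `n^m` possible arrays, so it is only usable on tiny inputs such as the examples below, but it can serve as a reference when checking a faster solution.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical brute-force reference: tries every way to pick one element per row.
// Only practical for tiny inputs, since there are n^m arrays in total.
class BruteForce {
    static int kthSmallestBrute(int[][] mat, int k) {
        List<Integer> sums = new ArrayList<>();
        collect(mat, 0, 0, sums);
        Collections.sort(sums);
        return sums.get(k - 1); // k is 1-based
    }

    // Depth-first enumeration: pick one element from `row`, then recurse into the next row.
    private static void collect(int[][] mat, int row, int partial, List<Integer> sums) {
        if (row == mat.length) {
            sums.add(partial); // one complete array has been formed
            return;
        }
        for (int value : mat[row])
            collect(mat, row + 1, partial + value, sums);
    }
}
```

For instance, on `[[1,3,11],[2,4,6]]` it sorts the nine sums `3, 5, 5, 7, 7, 9, 13, 15, 17` and returns `7` for `k = 5`.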

**Example 1:**

```
Input: mat = [[1,3,11],[2,4,6]], k = 5
Output: 7
Explanation: Choosing one element from each row, the first k smallest sums are:
[1,2], [1,4], [3,2], [3,4], [1,6]. The 5th sum is 7.
```

**Example 2:**

```
Input: mat = [[1,3,11],[2,4,6]], k = 9
Output: 17
```

**Example 3:**

```
Input: mat = [[1,10,10],[1,4,5],[2,3,6]], k = 7
Output: 9
Explanation: Choosing one element from each row, the first k smallest sums are:
[1,1,2], [1,1,3], [1,4,2], [1,4,3], [1,1,6], [1,5,2], [1,5,3]. The 7th sum is 9.
```

**Example 4:**

```
Input: mat = [[1,1,10],[2,2,9]], k = 7
Output: 12
```

**Constraints:**

* `m == mat.length`
* `n == mat[i].length`
* `1 <= m, n <= 40`
* `1 <= k <= min(200, n ^ m)`
* `1 <= mat[i][j] <= 5000`
* `mat[i]` is a non-decreasing array.

```java
import java.util.PriorityQueue;

class Solution {
    public int kthSmallest(int[][] mat, int k) {
        // Fold in one row at a time, keeping only the k smallest partial sums.
        int[] result = mat[0];
        for (int i = 1; i < mat.length; i++)
            result = kSmallestSums(result, mat[i], k);
        return result[k - 1];
    }

    // Returns the (at most) k smallest values of nums1[i] + nums2[j] in ascending order.
    // Both arrays must be sorted in non-decreasing order.
    private int[] kSmallestSums(int[] nums1, int[] nums2, int k) {
        if (nums1.length == 0 || nums2.length == 0 || k == 0)
            return new int[0];
        // Each heap entry is {index into nums1, index into nums2}, ordered by the pair's sum.
        PriorityQueue<int[]> pq = new PriorityQueue<>(
                (a, b) -> Integer.compare(nums1[a[0]] + nums2[a[1]], nums1[b[0]] + nums2[b[1]]));
        // Seed with nums1[i] + nums2[0]; only the first k elements of nums1 can matter.
        for (int i = 0; i < nums1.length && i < k; i++)
            pq.offer(new int[] {i, 0});
        int[] ans = new int[Math.min(k, nums1.length * nums2.length)];
        int count = 0;
        while (count < ans.length) {
            int[] cur = pq.poll();
            ans[count++] = nums1[cur[0]] + nums2[cur[1]];
            // Advance along nums2 for the same nums1 element, unless we reached its end.
            if (cur[1] + 1 < nums2.length)
                pq.offer(new int[] {cur[0], cur[1] + 1});
        }
        return ans;
    }
}
```
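The solution folds the rows in one at a time and keeps only the `k` smallest partial sums after each step. A sketch of why discarding the rest does not change the answer: write the total of any chosen array as a prefix over rows `0..i` plus a suffix over the remaining rows, where `c_r` is the column picked in row `r`:

$$
S = P_i + R_i, \qquad P_i = \sum_{r=0}^{i} \mathit{mat}[r][c_r], \qquad R_i = \sum_{r=i+1}^{m-1} \mathit{mat}[r][c_r]
$$

Let `v` be the `k`th smallest total. If every one of the `k` smallest arrays uses a kept prefix, there is nothing to show. Otherwise one of them uses a discarded prefix `P_i`, so at least `k` kept prefixes have sums `<= P_i`; pairing each with the same suffix `R_i` gives `k` arrays built only from kept prefixes, each with total `<= P_i + R_i <= v`. Restricting to kept prefixes can only remove arrays, so the `k`th smallest total among them is also `>= v`, hence equal to `v`. Repeating the argument at every fold (for any rank `j <= k`) shows that each intermediate `result` holds exactly the `k` smallest partial sums. For Example 3 (`k = 7`):

* after row 0: `[1, 10, 10]`
* after row 1: `[2, 5, 6, 11, 11, 14, 14]`
* after row 2: `[4, 5, 7, 8, 8, 8, 9]`, so `result[6] = 9`

Each fold seeds and polls at most `k` heap entries and the heap never holds more than `k`, so the whole computation takes O(m · k log k) time.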
