Skip to content

2736. Maximum Sum Queries 👍

  • Time: $O(\texttt{sort}(n) + \texttt{sort}(q) + q\log n)$
  • Space: $O(n + q)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// A point (x, y) formed from nums1[i] and nums2[i]. Members default to 0 so a
// default-constructed Pair never holds indeterminate values; the type stays an
// aggregate, so `{x, y}` initialization and structured bindings still work.
struct Pair {
  int x = 0;
  int y = 0;
};

// A query (minX, minY) tagged with its position in `queries`, so the answer
// can be written back in input order after the queries have been re-sorted.
// Members default to 0 so a default-constructed value is never indeterminate.
struct IndexAndQuery {
  int index = 0;  // the index in `queries`
  int minX = 0;   // queries[i] := (minX, minY)
  int minY = 0;
};

class Solution {
 public:
  vector<int> maximumSumQueries(vector<int>& nums1, vector<int>& nums2,
                                vector<vector<int>>& queries) {
    const vector<Pair> pairs = getPairs(nums1, nums2);
    const vector<IndexAndQuery> indexAndQueries = getIndexAndQueries(queries);
    vector<int> ans(queries.size());
    vector<pair<int, int>> stack;  // [(y, x + y)]

    int i = 0;
    for (const auto& [index, minX, minY] : indexAndQueries) {
      while (i < pairs.size() && pairs[i].x >= minX) {
        const auto [x, y] = pairs[i++];
        // x + y is a better candidate. Given that x is decreasing, the
        // condition "x + y >= stack.back().second" suggests that y is
        // relatively larger, thereby making it a better candidate.
        while (!stack.empty() && x + y >= stack.back().second)
          stack.pop_back();
        if (stack.empty() || y > stack.back().first)
          stack.emplace_back(y, x + y);
      }
      const auto it = lower_bound(stack.begin(), stack.end(),
                                  pair<int, int>{minY, INT_MIN});
      ans[index] = it == stack.end() ? -1 : it->second;
    }

    return ans;
  }

 private:
  vector<Pair> getPairs(const vector<int>& nums1, const vector<int>& nums2) {
    vector<Pair> pairs;
    for (int i = 0; i < nums1.size(); ++i)
      pairs.push_back({nums1[i], nums2[i]});
    sort(pairs.begin(), pairs.end(),
         [](const Pair& a, const Pair& b) { return a.x > b.x; });
    return pairs;
  }

  vector<IndexAndQuery> getIndexAndQueries(const vector<vector<int>>& queries) {
    vector<IndexAndQuery> indexAndQueries;
    for (int i = 0; i < queries.size(); ++i)
      indexAndQueries.push_back(IndexAndQuery{
          .index = i,
          .minX = queries[i][0],
          .minY = queries[i][1],
      });
    sort(indexAndQueries.begin(), indexAndQueries.end(),
         [](const IndexAndQuery& a, const IndexAndQuery& b) {
      return a.minX > b.minX;
    });
    return indexAndQueries;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class Solution {
  /**
   * Answers each query (minX, minY) with the maximum nums1[j] + nums2[j] over
   * indices j where nums1[j] >= minX and nums2[j] >= minY, or -1 if none.
   *
   * <p>Queries are answered offline in decreasing minX. Points are pushed onto
   * a monotonic stack in decreasing x, so every point on the stack satisfies
   * x >= minX for the current query. From bottom to top the stack holds
   * strictly increasing y and strictly decreasing x + y, so the answer is the
   * first entry whose y >= minY.
   *
   * @param nums1 x-coordinates; assumed to be the same length as nums2
   * @param nums2 y-coordinates
   * @param queries each entry is {minX, minY}
   * @return the answer for each query, in the original query order
   */
  public int[] maximumSumQueries(int[] nums1, int[] nums2, int[][] queries) {
    List<MyPair> pairs = getPairs(nums1, nums2);
    List<IndexAndQuery> indexAndQueries = getIndexAndQueries(queries);
    int[] ans = new int[queries.length];
    // Uses a private record instead of javafx.util.Pair, which is not part of
    // Java SE, and avoids boxing every entry.
    List<Candidate> stack = new ArrayList<>(); // [(y, x + y)]

    int i = 0;
    for (IndexAndQuery indexAndQuery : indexAndQueries) {
      final int index = indexAndQuery.index();
      final int minX = indexAndQuery.minX();
      final int minY = indexAndQuery.minY();
      while (i < pairs.size() && pairs.get(i).x() >= minX) {
        final MyPair pair = pairs.get(i++);
        final int x = pair.x();
        final int y = pair.y();
        final int sum = x + y;
        // x is non-increasing, so sum >= top.sum implies y >= top.y: the new
        // point dominates the top, which can never be the best answer again.
        while (!stack.isEmpty() && sum >= stack.get(stack.size() - 1).sum())
          stack.remove(stack.size() - 1);
        // Now top.sum > sum; keep the new point only if it reaches a larger y.
        if (stack.isEmpty() || y > stack.get(stack.size() - 1).y())
          stack.add(new Candidate(y, sum));
      }

      final int j = firstGreaterEqual(stack, minY);
      ans[index] = j == stack.size() ? -1 : stack.get(j).sum();
    }

    return ans;
  }

  private record MyPair(int x, int y) {}

  private record IndexAndQuery(int index, int minX, int minY) {}

  /** A monotonic-stack entry: a point's y and its x + y. */
  private record Candidate(int y, int sum) {}

  /**
   * Returns the smallest index whose entry has y >= target, or A.size() if
   * none; A must be sorted by y in non-decreasing order.
   */
  private int firstGreaterEqual(List<Candidate> A, int target) {
    int l = 0;
    int r = A.size();
    while (l < r) {
      final int m = (l + r) / 2;
      if (A.get(m).y() >= target)
        r = m;
      else
        l = m + 1;
    }
    return l;
  }

  /** Zips nums1/nums2 into points sorted by x in decreasing order. */
  private List<MyPair> getPairs(int[] nums1, int[] nums2) {
    List<MyPair> pairs = new ArrayList<>();
    for (int i = 0; i < nums1.length; ++i)
      pairs.add(new MyPair(nums1[i], nums2[i]));
    // Integer.compare instead of subtraction, which can overflow.
    pairs.sort((a, b) -> Integer.compare(b.x(), a.x()));
    return pairs;
  }

  /** Tags each query with its original index and sorts by minX decreasing. */
  private List<IndexAndQuery> getIndexAndQueries(int[][] queries) {
    List<IndexAndQuery> indexAndQueries = new ArrayList<>();
    for (int i = 0; i < queries.length; ++i)
      indexAndQueries.add(new IndexAndQuery(i, queries[i][0], queries[i][1]));
    indexAndQueries.sort((a, b) -> Integer.compare(b.minX(), a.minX()));
    return indexAndQueries;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Pair:
  """A point built from one (nums1[i], nums2[i]) entry; unpacks as (x, y)."""

  def __init__(self, x: int, y: int):
    self.x, self.y = x, y

  def __iter__(self):
    # Enables `x, y = pair`: x is produced first, then y.
    yield from (self.x, self.y)


class IndexAndQuery:
  """A query (minX, minY) tagged with its position in the input list.

  Unpacks as (index, minX, minY).
  """

  def __init__(self, index: int, minX: int, minY: int):
    self.index, self.minX, self.minY = index, minX, minY

  def __iter__(self):
    # Enables `index, minX, minY = query` in that fixed order.
    yield from (self.index, self.minX, self.minY)


class Solution:
  def maximumSumQueries(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
    """For each query (minX, minY), returns max(nums1[j] + nums2[j]) over all
    j with nums1[j] >= minX and nums2[j] >= minY, or -1 if no j qualifies.

    Queries are answered offline in decreasing minX while points are fed, in
    decreasing x, into a monotonic stack of (y, x + y) entries.
    """
    points = [Pair(nums1[k], nums2[k]) for k in range(len(nums1))]
    points.sort(key=lambda p: p.x, reverse=True)
    pending = [IndexAndQuery(k, q[0], q[1]) for k, q in enumerate(queries)]
    pending.sort(key=lambda iq: iq.minX, reverse=True)

    answers = [0] * len(queries)
    # Bottom to top: y strictly increasing, x + y strictly decreasing.
    candidates = []

    nxt = 0
    for qi, min_x, min_y in pending:
      # Admit every point that now satisfies x >= minX.
      while nxt < len(points) and points[nxt].x >= min_x:
        px, py = points[nxt]
        total = px + py
        # x only decreases, so a top whose total <= this total also has y <= py.
        while candidates and candidates[-1][1] <= total:
          candidates.pop()
        # Worth keeping only if it reaches a larger y than the current top.
        if not candidates or candidates[-1][0] < py:
          candidates.append((py, total))
        nxt += 1
      pos = self._firstGreaterEqual(candidates, min_y)
      answers[qi] = candidates[pos][1] if pos < len(candidates) else -1

    return answers

  def _firstGreaterEqual(self, A: List[Tuple[int, int]], target: int) -> int:
    """Returns the smallest k with A[k][0] >= target, or len(A) if none.

    A must be sorted by its first component.
    """
    lo, hi = 0, len(A)
    while lo < hi:
      mid = (lo + hi) // 2
      if A[mid][0] < target:
        lo = mid + 1
      else:
        hi = mid
    return lo