# 【題解】LeetCode 1373. Maximum Sum BST in Binary Tree

/**
* Definition for a binary tree node.
* struct TreeNode {
*     int val;
*     TreeNode *left;
*     TreeNode *right;
*     TreeNode() : val(0), left(nullptr), right(nullptr) {}
*     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
*     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
#include <algorithm>
#include <vector>

/**
 * Maximum key sum over all subtrees that are binary search trees.
 *
 * Nodes are numbered in pre-order (root = 1). For the subtree rooted at id:
 *   dp[id] - sum of its keys if it is a BST, otherwise the sentinel `inf`;
 *   mn[id] - smallest key in it (only meaningful when it is a BST);
 *   mx[id] - largest key in it  (only meaningful when it is a BST).
 * The empty subtree counts as a BST with sum 0, so the result is never
 * negative.
 *
 * NOTE(review): sums are kept in `int`, and a valid BST whose sum equals
 * `inf` would be misread as invalid. Both rely on the input bounds (presumably
 * the LeetCode 1373 limits); confirm before reusing with other inputs.
 */
class Solution {
public:
    int ans = 0, idx = 1, inf = 1e9;
    std::vector<int> dp, mx, mn;

    /// Post-order pass over the non-null subtree `x`, whose id is `id`.
    /// Precondition: `id` already has a slot in dp/mx/mn (maxSumBST sets up id 1).
    /// Fills dp/mn/mx[id] and folds every BST sum it finds into `ans`.
    void dfs(TreeNode* x, int id) {
        TreeNode *l = x->left, *r = x->right;
        mn[id] = x->val;
        mx[id] = x->val;
        bool isBst = true;
        int sum = x->val;
        if (l) {
            const int lid = newId();
            dfs(l, lid);
            mn[id] = mn[lid];
            // Every key on the left must be strictly smaller than this node's key.
            if (dp[lid] == inf || mx[lid] >= x->val) isBst = false;
            else sum += dp[lid];
        }
        if (r) {
            // The right side is still visited even if the left side failed:
            // BSTs inside it must be counted.
            const int rid = newId();
            dfs(r, rid);
            mx[id] = mx[rid];
            // Every key on the right must be strictly greater than this node's key.
            if (dp[rid] == inf || x->val >= mn[rid]) isBst = false;
            else sum += dp[rid];
        }
        // Always write dp[id]; never read it before this point.
        dp[id] = isBst ? sum : inf;
        if (isBst) ans = std::max(ans, sum);
    }

    /// Returns the largest key sum of any BST subtree of `root` (0 if none is
    /// positive, or if `root` is null). Safe to call repeatedly on one object.
    int maxSumBST(TreeNode* root) {
        // Reset every piece of per-call state so reuse cannot leak old results.
        ans = 0;
        idx = 1;
        dp.assign(2, 0);  // slot 0 unused, slot 1 is the root
        mx.assign(2, 0);
        mn.assign(2, 0);
        if (root == nullptr) return 0;
        dfs(root, 1);
        return ans;
    }

private:
    /// Allocates the next pre-order id and a zeroed slot for it in each table.
    int newId() {
        dp.push_back(0);
        mx.push_back(0);
        mn.push_back(0);
        return ++idx;
    }
};