99. 恢复二叉搜索树
给你二叉搜索树的根节点
root
,该树中的两个节点被错误地交换。请在不改变其结构的情况下,恢复这棵树。进阶:使用 O(n) 空间复杂度的解法很容易实现。你能想出一个只使用常数空间的解决方案吗?
示例 1:

输入:root = [1,3,null,null,2] 输出:[3,1,null,null,2] 解释:3 不能是 1 左孩子,因为 3 > 1 。交换 1 和 3 使二叉搜索树有效。
示例 2:

输入:root = [3,1,4,null,null,2] 输出:[2,1,4,null,null,3] 解释:2 不能在 3 的右子树中,因为 2 < 3 。交换 2 和 3 使二叉搜索树有效。
提示:
- 树上节点的数目在范围
[2, 1000]
内
2
31
<= Node.val <= 2
31
- 1
法1 递归
思路
中序遍历过程中,记录错误两个错误排序节点,最后进行交换
对二叉搜索树进行 中序遍历的时候 访问到的元素是从小到大顺序排列的。所以遍历的时候,记住上一个访问的元素,然后和当前元素相比。
一旦违反了升序的规则,那就说明这两个节点其中一个必有问题,难就难在要想清楚在这个问题中:
- 这个规定只能违反一次,或两次
- 如果违反一次,那么这两个节点均有问题
- 如果是两次违反,那么第一次违反是prev的问题,第二次违反是当前节点的问题。
所以实现的时候,采用两个变量
p1 p2
记录违反规则的节点,初始均为None
:- 拿
p2
指针指向违反规则时的当前节点,并总是更新。这样一样不管违反几次,p2
总是指向当前节点。满足条件。
- 拿
p1
指向违反规则的pre
节点。指向的时候,判断p1是否为空,如果为空,说明是首次违反的pre
节点,赋值即可。如果不为空,说明已经是第二次违反规则了,此时p1
不更新
图解
如下二叉搜索树,左边的为有问题的二叉搜索树,右边为恢复的二叉搜索树

中序遍历分别如下,可以发现有两处违反了升序的规则

那么
p1
和p2
指针分别应该指向第一次违反规则的pre
,和第二次违反规则的当前node

题解
# Definition for a binary tree node. # class TreeNode: # def __init__(self, val=0, left=None, right=None): # self.val = val # self.left = left # self.right = right class Solution: def __init__(self): self.p1 = None self.p2 = None self.pre = None def recoverTree(self, root: TreeNode) -> None: """ Do not return anything, modify root in-place instead. """ self.inOrder(node=root) self.p1.val, self.p2.val = self.p2.val, self.p1.val def inOrder(self, node): if node is None: return self.inOrder(node.left) if self.pre and node.val < self.pre.val: # 第一个记录节点只记录初次违反规则时的pre节点 if self.p1 is None: self.p1 = self.pre # 第二个节点要一直更新 self.p2 = node self.pre = node self.inOrder(node.right)
法 2 迭代写法
思路
中序遍历节点,最多有两对节点出错。
如果只有一对节点出错,说明交换的两个节点挨着的。
如果有两对节点出错,那么是不挨着的,此时交换第一对第一个 和 第二对第二个。
总之不管错了几对,一定是交换 所有出错节点的 第一个 和最后一个。
因此,可以把出错节点存起来,然后最后交换第一个和最后一个的节点值
题解
# Definition for a binary tree node. # class TreeNode: # def __init__(self, val=0, left=None, right=None): # self.val = val # self.left = left # self.right = right class Solution: def recoverTree(self, root: Optional[TreeNode]) -> None: """ Do not return anything, modify root in-place instead. """ # 存放 出错节点 error = [] pre = None cur = root stack = [] while stack or cur: if cur: stack.append(cur) cur = cur.left else: cur = stack.pop() if pre and pre.val > cur.val: error.append(pre) error.append(cur) pre = cur cur = cur.right # 交换 error[-1].val, error[0].val = error[0].val, error[-1].val