230. Kth Smallest Element in a BST

https://leetcode.com/problems/kth-smallest-element-in-a-bst/

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

 

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

 

Constraints:

  • The number of elements of the BST is between 1 to 10^4.
  • You may assume k is always valid, 1 ≤ k ≤ BST's total elements.
---


/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode() {}
* TreeNode(int val) { this.val = val; }
* TreeNode(int val, TreeNode left, TreeNode right) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
int ans = 0, count = 0;
public int kthSmallest(TreeNode root, int k) {
if (root == null) {
return 0;
}
dfs(root, k);
return ans;
}
private void dfs(TreeNode node, int k) {
if (node.left != null) {
dfs(node.left, k);
}
count++;
if (count == k) {
ans = node.val;
return;
}
if (node.right != null) {
dfs(node.right, k);
}
}
}
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode() {}
* TreeNode(int val) { this.val = val; }
* TreeNode(int val, TreeNode left, TreeNode right) {
* this.val = val;
* this.left = left;
* this.right = right;
* }
* }
*/
class Solution {
public int kthSmallest(TreeNode root, int k) {
if (root == null) {
return 0;
}
Stack<TreeNode> stack = new Stack<TreeNode>();
while (!stack.isEmpty() || root != null) {
while (root != null) {
stack.push(root);
root = root.left;
}
root = stack.pop();
if (k == 1) {
break;
}
k--;
root = root.right;
}
return root.val;
}
}
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
stack = []
while True:
while root:
stack.append(root)
root = root.left
root = stack.pop()
k -= 1
if k == 0:
return root.val
root = root.right
return -1