Problem Statement
Definition of a complete binary tree from Wikipedia:
In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between $1$ and $2^h$ nodes inclusive at the last level $h$.
Example:
Input:
    1
   / \
  2   3
 / \  /
4   5 6
Output: 6
The easiest way is to traverse the whole tree directly. Whether you use DFS or BFS, the time complexity is $O(N)$. However, this approach ignores the fact that the tree is complete, and the optimized solution is not obvious. The key is to find how many nodes the bottom layer has.

Number the nodes level by level from left to right, with the root at index 1, so that the children of node $k$ are $2k$ and $2k+1$. The indices of the nodes on layer $h$ (the root is layer 0) then lie in the interval $[2^h, 2^{h+1}-1]$.

Suppose there is a function that tells you whether a given index is in the tree. You can then binary-search $[2^h, 2^{h+1}-1]$ for the largest index that is in the tree. Because the tree is complete, every index from 1 to $N$ is present, so that largest index equals the total number of nodes $N$. The bottom layer holds up to about $N/2$ nodes (each additional layer doubles the number of nodes), so the binary search needs $O(\log N)$ probes.
As for the function that tells whether a node is in the tree: writing out a few examples shows that if a node's index is $n$, its parent's index is $\lfloor n/2 \rfloor$. Repeatedly halving the index therefore gives the path from the target node back to the root. For example, if the target node is 5, the path is [5, 2, 1]; if the target node is 8, the path is [8, 4, 2, 1].

Then start from the root (node 1) and follow this path. Go to the left child when the next index on the path is even, and to the right child when it is odd. If we reach the target, the node is in the tree; otherwise it is not. This function requires $O(\log N)$ time and $O(\log N)$ space for the path.
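Finally, there are $O(\log N)$ searches, each taking $O(\log N)$ time, so the total time is $O((\log N)^2)$. $O(\log N)$ space is used to store the path. Alternatively, the path can be read directly from the binary digits of the index: below the leading 1 bit, a 0 means "go left" and a 1 means "go right". This needs only $O(1)$ extra space.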
Here are the implementations
C++
class Solution {
public:
int countNodes(TreeNode* root) {
if (root == NULL) {
return 0;
}
int h = findHeight(root);
int i_min = pow(2, h);
int i_max = pow(2, h + 1) - 1;
int ans = i_min;
while(i_min <= i_max) {
int m = i_min + (i_max - i_min) / 2;
if (inTree(root, m)) {
ans = m;
i_min = m + 1;
} else{
i_max = m - 1;
}
}
return ans;
}
int findHeight(TreeNode* root) {
int h = 0;
TreeNode* cur = root;
while (cur != NULL) {
cur = cur->left;
h += 1;
}
return h - 1;
}
bool inTree(TreeNode* root, int i) {
if (root == NULL || i < 0) {
return false;
}
vector<int> path;
int cur = i;
while (cur != 0) {
path.push_back(cur);
cur /= 2;
}
TreeNode* curNode = root;
for (int j = path.size() - 2; j >= 0; j--) {
if (path[j] == path[j + 1] * 2) {
curNode = curNode->left;
} else {
curNode = curNode->right;
}
if (curNode == NULL) {
return false;
}
}
return true;
}
};
Java
class Solution {
    /**
     * Counts the nodes of a complete binary tree in O((log N)^2) time and O(1) extra space.
     *
     * <p>Nodes are numbered level by level, left to right, starting at 1 for the root, so the
     * children of node k are 2k and 2k + 1 and the nodes on level h have indices in
     * [2^h, 2^(h+1) - 1]. Because the tree is complete, every index 1..N is present, so the
     * node count N equals the largest bottom-level index that is in the tree; it is found by
     * binary search using {@link #inTree}.
     *
     * @param root root of a complete binary tree, or null for an empty tree
     * @return the number of nodes (0 for an empty tree)
     */
    public int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int h = findHeight(root);
        // Integer shifts instead of Math.pow. For a tree whose size fits in an int, h <= 30,
        // so lo <= 2^30 and hi = 2^(h+1) - 1 <= Integer.MAX_VALUE.
        int lo = 1 << h;         // leftmost bottom-level index; always present
        int hi = lo | (lo - 1);  // rightmost possible bottom-level index, 2^(h+1) - 1
        // Invariant: index lo is in the tree and every index > hi is not. Rounding mid up
        // guarantees progress, and no intermediate value exceeds hi, so nothing overflows.
        while (lo < hi) {
            int mid = lo + (hi - lo + 1) / 2;
            if (inTree(root, mid)) {
                lo = mid;
            } else {
                hi = mid - 1;
            }
        }
        return lo;
    }

    /**
     * Returns the number of edges on the leftmost root-to-leaf path: 0 for a single node,
     * -1 for an empty tree. In a complete tree this is the index of the bottom level.
     *
     * @param root root of the tree (may be null)
     * @return height of the leftmost path, or -1 when root is null
     */
    public int findHeight(TreeNode root) {
        int h = -1;
        for (TreeNode cur = root; cur != null; cur = cur.left) {
            h++;
        }
        return h;
    }

    /**
     * Reports whether the node with level-order index i (root = 1) exists.
     *
     * <p>The binary digits of i below its leading 1 bit spell the path from the root, most
     * significant first: 0 = go to the left child, 1 = go to the right child. This is the
     * same path as repeatedly halving i, but it needs no auxiliary list.
     *
     * @param root root of the tree (may be null)
     * @param i    1-based level-order index
     * @return true if the node exists; false for an empty tree or a non-positive index
     */
    public boolean inTree(TreeNode root, int i) {
        if (root == null || i <= 0) {
            return false;  // index 0 and negative indices never name a node
        }
        // Highest power of two <= i, e.g. 4 for i = 5 (binary 101).
        int mask = Integer.highestOneBit(i);
        TreeNode node = root;
        for (mask >>= 1; mask > 0; mask >>= 1) {
            node = (i & mask) != 0 ? node.right : node.left;
            if (node == null) {
                return false;
            }
        }
        return true;
    }
}