算法练习 Leetcode 222 Count Complete Tree Nodes 解法

leetcode 222 解法

题目如下:

Definition of a complete binary tree from Wikipedia:
In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2h nodes inclusive at the last level h.

Example:

Input: 
    1
   / \
  2   3
 / \  /
4  5 6

Output: 6

最简单的做法就是直接遍历一边,不管是DFS还是BFS,时间复杂度都是O(N)。但是这样就没利用到complete binary tree的条件。优化的解法如果没练过的话其实不太好想。这题要的其实是找到最后一层有几个node。而每层的node的index就是在[2^h, 2^(h+1)-1]区间之内。如果有个function可以告诉你某个index在不在tree里,那么可以对[2^h, 2^(h+1)-1]做binary search,找到其中index最大并且在tree里的node。Binary tree的最后一层大概有N/2个node(每加一层,node数量翻倍!),那么需要做O(lgN)次search。

至于这个告诉某个node在不在tree里的函数可以这样写:写几个例子观察一下,如果一个node的index是n, 每个node的parent的index都是floor(n/2)。那么我们就能把从root到达目标node的path找出来,比如如果目标node是5,那么path就是[5, 2, 1],入股目标是8,那么path就是[8, 4, 2, 1]。然后从root,也就是1号node出发,如果最后我们按照这个地图能走到目标,那么这个node就在tree里面,否则就不在。这个function需要O(lgN)时间和O(lgN)空间。

最后总的时间复杂度这样计算,O(lgN)次search,每次 O(lgN)时间,总时间O((lgN)^2)。空间O(lgN)用来保存path。

代码如下

C++

class Solution {
public:
    int countNodes(TreeNode* root) {
        if (root == NULL) {
            return 0;
        }
        int h = findHeight(root);
        int i_min = pow(2, h);
        int i_max = pow(2, h + 1) - 1;
        int ans = i_min;
        while(i_min <= i_max) {
            int m = i_min + (i_max - i_min) / 2;
            if (inTree(root, m)) {
                ans = m;
                i_min = m + 1;
            } else{
                i_max = m - 1;
            }
        }   
        return ans;
        
    }
    
    int findHeight(TreeNode* root) {
        int h = 0;
        TreeNode* cur = root;
        while (cur != NULL) {
            cur = cur->left;
            h += 1;
        }
        return h - 1;
    }
    
    bool inTree(TreeNode* root, int i) {
        if (root == NULL || i < 0) {
            return false;
        }
        vector<int> path;
        int cur = i;
        while (cur != 0) {
            path.push_back(cur);
            cur /= 2;
        }
        TreeNode* curNode = root;
        for (int j = path.size() - 2; j >= 0; j--) {
            if (path[j] == path[j + 1] * 2) {
                curNode = curNode->left;
            } else {
                curNode = curNode->right;
            }
            if (curNode == NULL) {
                return false;
            }
        }
        return true;
    }
};

Java

class Solution {
    public int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int h = findHeight(root);
        int i_min = (int)Math.pow(2, h);
        int i_max = (int)(Math.pow(2, h + 1) - 1);
        int ans = i_min;
        while (i_min <= i_max) {
            int m = i_min + (i_max - i_min) / 2;
            if (inTree(root, m)) {
                ans = m;
                i_min = m + 1;
            } else{
                i_max = m - 1;
            }
        }   
        return ans;
        
    }
    
    public int findHeight(TreeNode root) {
        int h = 0;
        TreeNode cur = root;
        while (cur != null) {
            cur = cur.left;
            h += 1;
        }
        return h - 1;
    }
    
    public boolean inTree(TreeNode root, int i) {
        if (root == null || i < 0) {
            return false;
        }
        List<Integer> path = new ArrayList<>();
        int cur = i;
        while (cur != 0) {
            path.add(cur);
            cur /= 2;
        }
        TreeNode curNode = root;
        for (int j = path.size() - 2; j >= 0; j--) {
            if (path.get(j) == path.get(j + 1) * 2) {
                curNode = curNode.left;
            } else {
                curNode = curNode.right;
            }
            if (curNode == null) {
                return false;
            }
        }
        return true;
    }
}