【題解】ZeroJudge c463: apcs 樹狀圖分析 (Tree Analyses)

【範例】樹上的DFS
【題目敘述】https://zerojudge.tw/ShowProblem?problemid=c463

  • 利用DFS遍歷一棵樹,計算每個節點的高度。
  • 陣列 p[ ]:紀錄每個節點的parent。
  • 陣列 h[ ]:紀錄每個節點的高度。
  • vector <int> v:紀錄每個節點的child。
  • 遞迴函式 func( ):利用 DFS 遍歷每個node 並計算其高度,一個node的高度定義為:它所有的子節點中高度最大者加 1。
    • 遞迴終止條件:抵達 leaf 後。
    • 依題意,leaf 的高度為 0。
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

int n, h[100000], p[100000], num, tmp, root;
long long ans;
vector <int> v[100000];

void func(int num){
    if (v[num].empty()){
        h[num] = 0;
    }else{
        int mx = 0;
        for (int i:v[num]){
            func(i);
            mx = max(mx, h[i]+1);
        }
        h[num] = mx;
    }
}

int main() {
    while (cin >> n){
        memset(p, -1, sizeof(p));
        for (int i = 0; i < n; i++){
            v[i].clear();
            cin >> num;
            for (int j = 0; j < num; j++){
                cin >> tmp;
                // 轉換成zero-based
                tmp--;
                v[i].push_back(tmp);
                p[tmp] = i;
            }
        }
        for (int i = 0; i < n; i++){
            if (p[i] == -1){
                root = i;
                break;
            }
        }
        func(root);
        ans = 0;
        for (int i = 0; i < n; i++){
            ans += h[i];
        }
        cout << root+1 << "\n" << ans << "\n";
    }
}

Python 程式碼 (credit: Amy Chou)
跟C++用一樣的做法:NA (score:95%),最後一筆測資會遇到Segmentation fault (core dumped),可能是遞迴太深造成的。

# NA (score:95%)
# Segmentation fault (core dumped)
import sys
sys.setrecursionlimit(1000000)

def dfs(x):
    global h
    if len(g[x]) == 0:
        h[x] = 0
    else:
        mx = 0
        for i in g[x]:
            dfs(i)
            mx = max(mx, h[i] + 1)
        h[x] = mx
        
for line in sys.stdin:
    n = int(line.strip())
    g = [[] for _ in range(n)]
    pa = [-1 for _ in range(n)]
    for i in range(n):
        tmp = list(map(int, sys.stdin.readline().strip().split()))
        if tmp[0] > 0:
            g[i] = [k-1 for k in tmp[1:]]
            for j in g[i]:
                pa[j] = i
    for i in range(n):
        if pa[i] == -1:
            root = i
            break
        
    h = [0 for _ in range(n)]
    dfs(root)
    ans = 0
    for i in range(n):
        ans += h[i]
        
    print(root + 1)
    print(ans)

Python 換個做法:AC

# AC (0.1s, 7.2MB)
while True:
    try:
        n = int(input())
        
        p = [-1] * (n+1) # parent
        d = [-1] * (n+1) # distance
        d[0] = 0
        num = [0] * (n+1) # childrenCount
        for i in range(1, n+1):
            temp = list(map(int, input().split()))
            num[i] = temp.pop(0)
            for child in temp:
                p[child] = i
            
        child = []
        for i in range(1, n+1):
            if p[i] == -1:
                root = i
            if num[i] == 0:
                child.append(i)
                d[i] = 0
        
        while child:
            node = child.pop(0)
            d[p[node]] = max(d[p[node]], d[node] + 1)
            num[p[node]] -= 1
            if num[p[node]] == 0 and p[p[node]] != -1:
                child.append(p[node])
        
        print(root)
        print(sum(d))
    except:
        break
分享本文 Share with friends