# + 0 comments

    from collections import *

    # Reads a tree (n nodes labelled 1..n, node values, n-1 edges) from stdin and
    # prints value(1) * prod over v in 2..n of (value(v) - value(parent(v))),
    # reduced modulo MOD, where parent() comes from a BFS rooted at node 1.
    MOD = 10**9 + 7


    def _bfs_parents(edges, root=1):
        """Return a dict mapping every node reachable from `root` to its BFS parent.

        `root` itself maps to -1. `edges` is an iterable of (x, y) node pairs,
        treated as undirected.
        """
        graph = defaultdict(list)
        for x, y in edges:
            graph[x].append(y)
            graph[y].append(x)
        parent = {root: -1}
        queue = deque([root])
        while queue:
            curr = queue.popleft()
            for nei in graph[curr]:
                # Mark nodes when they are enqueued rather than only skipping the
                # immediate predecessor: a malformed input containing a cycle would
                # otherwise make the traversal loop forever.
                if nei not in parent:
                    parent[nei] = curr
                    queue.append(nei)
        return parent


    def solve(n, values, edges, mod=MOD):
        """Return values[0] * prod(value(v) - value(parent(v)) for v in 2..n) % mod.

        n: number of nodes, labelled 1..n.
        values: node values; values[i] belongs to node i + 1 (extra entries ignored).
        edges: iterable of (x, y) pairs of 1-based node labels.
        mod: modulus applied to the result (defaults to 10**9 + 7).

        Raises KeyError if some node in 2..n is not reachable from node 1.
        """
        vals = [0, *values]  # shift to 1-based so vals[v] is node v's value
        parent = _bfs_parents(edges)
        # Reduce up front: when n == 1 the loop below never runs, and the root
        # value must still be returned modulo `mod`.
        res = vals[1] % mod
        for v in range(2, n + 1):
            # Python's % is non-negative for a positive modulus, so negative
            # differences still yield a result in [0, mod).
            res = res * (vals[v] - vals[parent[v]]) % mod
        return res


    def main():
        """Read the tree from stdin and print the result."""
        n = int(input())
        # split() (not split(' ')) tolerates repeated or trailing whitespace.
        values = list(map(int, input().split()))
        edges = [tuple(map(int, input().split())) for _ in range(n - 1)]
        print(solve(n, values, edges))


    if __name__ == "__main__":
        main()