Birthday Triplets


    This problem may be one of my favorites, because it permits so many completely different approaches. (I'm not one of those ICPC wizards who hop on a problem, code it, and move on to the next...).

    Possible approaches:

    1. Direct solution (formula) -> don't know if possible, looks cumbersome
    2. Brute force on 2 numbers, estimate bounds -> done, see notes
    3. Number theory: solve modulo some numbers and apply CRT -> done, see notes
    4. Approximation (e.g. Newton) in the real numbers -> difficult, see notes
    5. -> Editorial: Brute force on 1 number + formula for other 2 numbers
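
    For approach 5, the editorial itself is not reproduced here, so the following is only a sketch of one way such a formula can work, assuming positive integers a <= b <= c. Once c is fixed, s = f2 - c^2 and t = f4 - c^4 are the sums of the squares and fourth powers of the other two numbers, and 2t - s^2 = (b^2 - a^2)^2, which gives a^2 and b^2 directly. The name solve_one_loop is made up for this illustration.

```python
import math

def solve_one_loop(f2, f3, f4):
    """Hypothetical helper: loop over c, recover a <= b from s and t."""
    c = 1
    while c * c < f2:
        s = f2 - c * c        # a^2 + b^2 if c is the right guess
        t = f4 - c ** 4       # a^4 + b^4 if c is the right guess
        disc = 2 * t - s * s  # equals (b^2 - a^2)^2 for the right c
        if disc >= 0:
            r = math.isqrt(disc)
            # r < s keeps a^2 positive; the parity check keeps a^2, b^2 integral
            if r * r == disc and r < s and (s - r) % 2 == 0:
                a2, b2 = (s - r) // 2, (s + r) // 2
                a, b = math.isqrt(a2), math.isqrt(b2)
                if (a * a == a2 and b * b == b2 and b <= c
                        and a ** 3 + b ** 3 + c ** 3 == f3):
                    return a, b, c
        c += 1
    return None

print(solve_one_loop(83, 495, 3107))  # power sums of 3, 5, 7
```

    f3 is only used as a final check here, which filters out guesses of c that happen to satisfy the f2 and f4 equations alone.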

    Beforehand one can observe:

    • the difficulty lies in finding a,b,c, the series are trivial
    • the order of a,b,c doesn't matter
    • there is only one positive solution (up to order)

    Now some notes to the approaches:

    (2) Brute force on c and b: the growth rate of f2..f3..f4 is dominated by the biggest term, c. That is, one can get upper and lower bounds for c from f4 and from f4/f2 (and similar bounds for b). However, that leaves too many combinations to probe, so I precomputed the valid triples (g2,g3,g4), where the gi are the possible sums of 2 numbers (^2,^3,^4) modulo a few small prime powers (81, 32, 125, 49 in the code below), and used them to filter c.
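
    To make the bound on c concrete: f = f4/f2 is a weighted average of a^2, b^2, c^2, so f <= c^2, and maximizing c^2/f over the ratios a^2/c^2 and b^2/c^2 gives the factor (sqrt(3)+1)/2 quoted in the docstring of solve_bf. The sketch below only illustrates that window; the helper name c_range and the EPS value are assumptions, not part of the original code.

```python
import math

GAMMA = (math.sqrt(3) + 1) / 2  # largest possible value of c^2 / (f4/f2)
EPS = 1e-6                      # assumed float tolerance

def c_range(f2, f4):
    """Return an inclusive integer window (lo, hi) that contains the largest number c."""
    f = f4 / f2
    lo = int(math.sqrt(f - EPS))          # from f <= c^2 (rounded down, so conservative)
    hi = int(math.sqrt(GAMMA * f + EPS))  # from c^2 <= GAMMA * f
    return lo, hi

# Power sums of 3, 5, 7: f2 = 83, f4 = 3107
print(c_range(83, 3107))  # (6, 7)
```

    For this triple only two candidates for c survive, which is why the outer loop stays short.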

    (3) Number theory: maybe the cleanest and fastest approach. Precompute the possible solutions, a dictionary of (a,b,c) keyed by (f2,f3,f4), modulo some primes; I used 23, 29, 31. From f2,f3,f4 you can look up the candidates (a,b,c mod each prime, which CRT turns back into the original a,b,c); the solution is among them.
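
    The CRT step can be illustrated in isolation. The residues modulo 23, 29 and 31 pin down any number below 23*29*31 = 20677, which is enough if the values stay under the 14_997 cap used in the brute-force code (an assumption about the input range). This is a minimal sketch with a made-up helper name, crt_combine, using Python 3.8+ for math.prod and the modular inverse via pow.

```python
from math import prod

PRIMES = (23, 29, 31)

def crt_combine(residues, moduli=PRIMES):
    """Return the unique x in [0, prod(moduli)) with x % m == r for each (r, m)."""
    total = prod(moduli)
    x = 0
    for r, m in zip(residues, moduli):
        rest = total // m                 # product of the other moduli
        x += r * rest * pow(rest, -1, m)  # pow(rest, -1, m): inverse of rest mod m
    return x % total

n = 12345
residues = [n % p for p in PRIMES]
print(residues, crt_combine(residues))  # the second value is 12345 again
```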

    (4) Approximation: it turns out that convergence depends strongly on a good starting point. I tried to tackle this with a grid of starting points, but that was much too slow, and even then it failed to converge for some test cases.
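
    A minimal sketch of what a Newton iteration for this system could look like, not the code that was actually tried. The Jacobian of the three power sums has determinant 24*a*b*c*(b-a)*(c-a)*(c-b), so it becomes singular whenever two coordinates coincide or one is zero, which may be one reason the starting point matters so much. The helper names and the starting point are chosen only for illustration.

```python
def det3(m):
    """Determinant of a 3x3 matrix given as three rows."""
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)

def solve3(m, rhs):
    """Solve the 3x3 linear system m * x = rhs with Cramer's rule."""
    d = det3(m)
    if d == 0:
        raise ZeroDivisionError("singular Jacobian")
    sol = []
    for col in range(3):
        mc = [list(row) for row in m]
        for row in range(3):
            mc[row][col] = rhs[row]
        sol.append(det3(mc) / d)
    return sol

def newton_power_sums(f2, f3, f4, start, steps=50, tol=1e-12):
    """Newton iteration in the reals for sum(x_i^k) = f_k, k = 2, 3, 4."""
    x = list(start)
    targets = ((2, f2), (3, f3), (4, f4))
    for _ in range(steps):
        residual = [sum(v ** k for v in x) - fk for k, fk in targets]
        jac = [[k * v ** (k - 1) for v in x] for k, _ in targets]
        step = solve3(jac, [-r for r in residual])
        x = [v + s for v, s in zip(x, step)]
        if max(abs(s) for s in step) < tol:
            break
    return x

# Power sums of 3, 5, 7, started close to the solution
print([round(v) for v in newton_power_sums(83, 495, 3107, (2.5, 5.5, 7.5))])
```

    With a start this close to the true triple the iteration settles quickly; a start with two equal coordinates makes solve3 raise immediately.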

    ## - - - - - - -
    # brute force with estimation of range
    
    import math
    
    EPS = 1e-6  # assumption: small float tolerance; the original value is not shown
    
    validx2 = None
    def solve_bf(f2,f3,f4):
        ''' Brute force with estimation of biggest number based on xi^4/xi^2:
            f <= c^2 <= gamma * f    with f=f4/f2 and gamma = (sqrt(3)+1)/2
            g <= b^2 <= beta * g  with g=(f4-c^4)/(f2-c^2) and beta = (sqrt(2)+1)/2 '''
        global validx2
        if not validx2:
            validx2 = {}
            for k in [81,32,125,49]:
                validx2[k] = set( tuple((a**i + b**i)%k  for i in (2,3,4))
                            for a in range(k) for b in range(k)  )
        found = []
        f = f4/f2
        gamma_f = (math.sqrt(3)+1) / 2 * f
        c2max = min(14_997**2,f2-5,math.sqrt(f4-17),gamma_f+EPS)
        count = 0
        for c in range(int(math.sqrt(f-EPS)),int(math.sqrt(c2max))+1):
            c2,c3,c4 = f2-c*c,f3-c**3,f4-c**4
            if any((c2%k,c3%k,c4%k) not in validx2[k] for k in [81,32,125,49]):
                continue
            g = (f4-c**4)/(f2-c*c)
            beta_g = (math.sqrt(2)+1) / 2 * g
            b2max = min((c-1)**2,f2-c*c-1,beta_g+EPS)
            for b in range(int(math.sqrt(g-EPS)),int(math.sqrt(b2max))+1):
                count += 1
                # check valid a
                a2 = f2-c*c-b*b
                if a2 < b*b and a2*a2 == f4 - c**4 -b**4:
                    a = math.isqrt(a2)
                    if a*a == a2 and a**3+b**3+c**3==f3:
                        found.append((a,b,c))
        # print('count:',count)
        found.sort()
        return found[0] if found else None  # None if no triple matched
    
    
    ## - - - - - - -
    # number theoretical approach - solve modulo and use CRT
    
    from collections import defaultdict
    from itertools import combinations_with_replacement as combwr, permutations, product
    
    validx3 = None
    def solve_nt(f2,f3,f4):
        '''Number theory: lookup table for a,b,c (mod some primes) and use CRT.'''
        global validx3
        mods = [23,29,31]
        if validx3 is None:
            validx3 = {}
            for p in mods:
                pmap = defaultdict(set)
                for a,b,c in combwr(range(p),3):
                    pmap[tuple((a**i+b**i+c**i)%p for i in [2,3,4])].add((a,b,c))
                validx3[p] = pmap
        all_remainders = [pmap[(f2%p,f3%p,f4%p)] for p,pmap in validx3.items()]
        for combo in product(*all_remainders):
            c23 = combo[0]
            for c29 in permutations(combo[1]):
                for c31 in permutations(combo[2]):
                    a = [
                        multicrt(((23,c23[i]),(29,c29[i]),(31,c31[i]))) for i in range(3)
                        ]
                    if all(sum(aa**k for aa in a)==fk for k,fk in ((2,f2),(3,f3),(4,f4))):
                        # print('...',sorted(a))
                        return sorted(a,reverse=True)
        return None
    
    
    def multicrt(mrs): # multi remainders: reduce by applying CRT successively
        mrs = list(mrs)
        m,r = mrs.pop()
        while mrs:
            m1,r1 = mrs.pop()
            r = crt(m,r,m1,r1)
            m *= m1
        return r
    
    def crt(m1,r1,m2,r2):
        g,b1,b2 = xgcd(m1,m2) # b1,b2 are the Bezout coefficients: b1*m1+b2*m2 = g
        if g != 1: # CRT needs b1*m1+b2*m2 = 1
            raise ValueError('Moduli not relatively prime')
        return (r2*b1*m1+r1*b2*m2) % (m1*m2)
    
    def xgcd(a,b):
        x0,x1,y0,y1 = 0,1,1,0
        while a!=0:
            q,b,a = b//a,a,b%a
            y0,y1 = y1,y0-q*y1
            x0,x1 = x1,x0-q*x1
        return b,x0,y0