Lego Blocks

  • + 0 comments

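    The main points to take care of are: (1) compute the number of ways to lay a single row of each width with dynamic programming, f(w) = f(w-1) + f(w-2) + f(w-3) + f(w-4), and raise it to the n-th power to get total(w), the number of walls of width w with no restriction; (2) apply inclusion/exclusion by splitting the wall into two parts at its leftmost full-height crack, good(m) = total(m) - sum over i = 1..m-1 of good(i) * total(m - i), memoizing these values with dynamic programming for performance; (3) compute the n-th powers with fast (binary) exponentiation, squaring the base and halving the exponent at each step.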

    Java code

    /** combinations[w]: ways to lay one row of width w, modulo {@link #modulo}; 0 = not computed yet. */
    public static long[] combinations;
        /**
         * Prime modulus applied to every count. Declared {@code final} so no other code can reassign
         * it and silently corrupt the modular arithmetic.
         */
        public static final long modulo = 1000000007;
        // Per-width memo tables, reallocated by legoBlocks on every call; 0 means "not computed yet".
        // They are shared static state, so concurrent legoBlocks calls would interfere with each other.
        /** totals[w]: walls of the current height and width w, with no solidity requirement. */
        private static long [] totals;
        /** goods[w]: solid walls (no full-height vertical crack) of width w. */
        private static long [] goods;
        /** bads[w]: walls of width w with at least one full-height vertical crack. */
        private static long [] bads;
        /**
         * Counts the walls of height {@code n} and width {@code m}, built from blocks of width 1 to 4,
         * that have no vertical crack running through the full height, modulo {@code modulo}.
         *
         * @param n wall height (number of rows); must be {@code >= 0}
         * @param m wall width; must be {@code >= 1}
         * @return the number of solid walls, in {@code [0, 1000000007)}
         * @throws IllegalArgumentException if {@code n < 0} or {@code m < 1}
         */
        public static int legoBlocks(int n, int m) {
            if (n < 0) {
                throw new IllegalArgumentException("n must be >= 0, got " + n);
            }
            if (m < 1) {
                // Without this check width 0 crashes on totals[1] below (array of length 1).
                throw new IllegalArgumentException("m must be >= 1, got " + m);
            }
            totals = new long[m + 1];
            goods = new long[m + 1];
            bads = new long[m + 1];
            // Base cases for widths 0 and 1: a width-1 row has one layout and a width-1 wall is solid.
            totals[0] = totals[1] = goods[0] = goods[1] = 1L;
            combinations = new long[m + 1];
            combinations[0] = 1L;
            // Fill the memo tables bottom-up so each call mostly finds smaller widths already cached;
            // this keeps the mutual recursion shallow instead of O(m) frames deep.
            for (int w = 2; w < m; w++) {
                calculateGood(n, w);
            }
            long good = calculateGood(n, m);
            return (int) (good % modulo);
        }
        
        /**
         * Returns the number of solid walls (no full-height vertical crack) of height {@code n} and
         * width {@code m}, modulo {@code modulo}, memoized in {@code goods}.
         *
         * <p>Computed by complement: all walls of this size minus those with at least one full crack.
         * A cached value of 0 is indistinguishable from "not computed", so such entries are recomputed.
         *
         * @param n wall height
         * @param m wall width; must be a valid index into {@code goods}
         * @return the count, reduced modulo {@code modulo}
         */
        public static Long calculateGood(int n, int m) {
            if (goods[m] != 0) {
                return goods[m];
            }
            // Primitive locals: the original boxed every intermediate value as Long.
            long total = calculateTotal(n, m);
            long bad = calculateBad(n, m);
            // Java's % keeps the sign of the dividend, so a negative difference must be shifted back up.
            long good = (total - bad) % modulo;
            if (good < 0) {
                good += modulo;
            }
            goods[m] = good;
            return good;
        }
        
        /**
         * Returns the number of width-{@code m}, height-{@code n} walls that contain at least one
         * full-height vertical crack, modulo {@code modulo}, memoized in {@code bads}.
         *
         * <p>Each such wall is counted exactly once, by the position {@code i} of its leftmost crack:
         * the part left of the crack must be solid ({@code calculateGood(n, i)}), while the part to its
         * right is unrestricted ({@code calculateTotal(n, m - i)}).
         */
        private static Long calculateBad(int n, int m) {
            if (bads[m] != 0) {
                return bads[m];
            }
            long totalBad = 0L; // primitive accumulator: avoids boxing on every iteration
            for (int i = 1; i < m; i++) {
                long currentBad = (calculateGood(n, i) * calculateTotal(n, m - i)) % modulo;
                totalBad = (totalBad + currentBad) % modulo;
            }
            // totalBad is already reduced inside the loop, so it can be stored directly.
            bads[m] = totalBad;
            return totalBad;
        }
        
        /**
         * Returns the number of width-{@code m} walls of height {@code n} with no solidity
         * requirement, i.e. (ways to lay one row)^n, modulo {@code modulo}, memoized in {@code totals}.
         */
        private static Long calculateTotal(int n, int m) {
            if (totals[m] != 0) {
                return totals[m];
            }
            // Rows are chosen independently, so the single-row count is raised to the n-th power.
            long perRow = calculateCombinationsPerRow(m);
            long total = modPow(perRow, n);
            totals[m] = total % modulo;
            return totals[m];
        }
        
        /**
         * Computes {@code base^exponent mod modulo} by binary exponentiation: the running square of
         * the base is multiplied in for each set bit of the exponent. A non-positive exponent yields 1.
         */
        private static long modPow(long base, int exponent) {
            long acc = 1L;
            long square = base % modulo;
            for (int e = exponent; e > 0; e >>= 1) {
                if ((e & 1) != 0) {
                    acc = acc * square % modulo;
                }
                square = square * square % modulo;
            }
            return acc;
        }
        
        /**
         * Returns the number of ways to lay a single row of width {@code width} using blocks of
         * width 1 to 4, modulo {@code modulo}, memoized in {@code combinations}.
         *
         * <p>Uses f(w) = f(w-1) + f(w-2) + f(w-3) + f(w-4), since the last block in the row has width
         * 1..4. The table is filled bottom-up, so large widths no longer need a recursion stack
         * proportional to {@code width}. An entry equal to 0 is treated as "not computed yet".
         */
        private static Long calculateCombinationsPerRow(int width) {
            if (combinations[width] != 0) {
                return combinations[width];
            }
            for (int w = 0; w <= width; w++) {
                if (combinations[w] != 0) {
                    continue; // already known, e.g. the preset base case f(0) = 1
                }
                long sum = 0L; // at most 4 terms below modulo each, so no long overflow
                for (int block = 1; block <= 4 && block <= w; block++) {
                    sum += combinations[w - block];
                }
                combinations[w] = sum % modulo;
            }
            return combinations[width];
        }