本文最后更新于2021年11月3日,已超过 1 个月没更新!

第五讲 动态规划

背包问题

01背包问题

import java.util.*;

public class Main {
    /**
     * Solves the 0/1 knapsack problem with a full two-dimensional table.
     *
     * <p>Reads the item count and the capacity from standard input, followed by one
     * (volume, value) pair per item, and prints the largest total value that fits.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner in = new Scanner(System.in);
        int itemCount = in.nextInt();
        int capacity = in.nextInt();
        int[] volume = new int[itemCount + 1];
        int[] value = new int[itemCount + 1];
        for (int idx = 1; idx <= itemCount; idx++) {
            volume[idx] = in.nextInt();
            value[idx] = in.nextInt();
        }

        // best[i][c]: largest value obtainable from the first i items within capacity c
        int[][] best = new int[itemCount + 1][capacity + 1];
        for (int item = 1; item <= itemCount; item++) {
            int[] prev = best[item - 1];
            int[] cur = best[item];
            for (int cap = 1; cap <= capacity; cap++) {
                int skip = prev[cap];
                if (cap < volume[item]) {
                    // the item does not fit, so the only option is to leave it out
                    cur[cap] = skip;
                } else {
                    int take = prev[cap - volume[item]] + value[item];
                    cur[cap] = Math.max(skip, take);
                }
            }
        }
        System.out.println(best[itemCount][capacity]);
    }
}

🍉 优化

import java.util.*;

public class Main {
    /**
     * Solves the 0/1 knapsack problem using a single rolling row.
     *
     * <p>Input format: item count and capacity, then one (volume, value) pair per item.
     * Prints the largest total value that fits into the capacity.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner in = new Scanner(System.in);
        int itemCount = in.nextInt();
        int capacity = in.nextInt();
        int[] volume = new int[itemCount + 1];
        int[] value = new int[itemCount + 1];
        int read = 0;
        while (read < itemCount) {
            read++;
            volume[read] = in.nextInt();
            value[read] = in.nextInt();
        }

        // best[c]: largest value within capacity c using the items processed so far
        int[] best = new int[capacity + 1];
        for (int item = 1; item <= itemCount; item++) {
            int vol = volume[item];
            int gain = value[item];
            // walk capacities downwards so best[cap - vol] still excludes the current item
            int cap = capacity;
            while (cap >= vol) {
                int withItem = best[cap - vol] + gain;
                if (withItem > best[cap]) {
                    best[cap] = withItem;
                }
                cap--;
            }
        }
        System.out.println(best[capacity]);
    }
}

完全背包问题

容量以内的最大价值

/*
01 背包: f[i][j] = Math.max(f[i - 1][j], f[i - 1][j - v[i]] + w[i])
完全背包: f[i][j] = Math.max(f[i - 1][j], f[i][j - v[i]] + w[i])
*/
import java.util.*;

public class Main {
    /**
     * Solves the complete (unbounded) knapsack problem with a two-dimensional table.
     *
     * <p>Input format: item count and capacity, then one (volume, value) pair per item.
     * Every item may be taken any number of times. Prints the largest total value.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner in = new Scanner(System.in);
        int itemCount = in.nextInt();
        int capacity = in.nextInt();
        int[] volume = new int[itemCount + 1];
        int[] value = new int[itemCount + 1];
        for (int idx = 1; idx <= itemCount; idx++) {
            volume[idx] = in.nextInt();
            value[idx] = in.nextInt();
        }

        // best[i][c]: largest value using item kinds 1..i within capacity c
        int[][] best = new int[itemCount + 1][capacity + 1];
        for (int item = 1; item <= itemCount; item++) {
            int vol = volume[item];
            int gain = value[item];
            int[] prev = best[item - 1];
            int[] cur = best[item];
            for (int cap = 1; cap <= capacity; cap++) {
                cur[cap] = prev[cap];
                if (cap >= vol) {
                    // look at the current row: the item may already have been taken
                    cur[cap] = Math.max(cur[cap], cur[cap - vol] + gain);
                }
            }
        }

        System.out.println(best[itemCount][capacity]);
    }
}

🍉 优化

import java.util.*;

public class Main {
    /**
     * Solves the complete (unbounded) knapsack problem using a single rolling row.
     *
     * <p>Input format: item count and capacity, then one (volume, value) pair per item.
     * Prints the largest total value that fits into the capacity.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner in = new Scanner(System.in);
        int itemCount = in.nextInt();
        int capacity = in.nextInt();
        int[] volume = new int[itemCount + 1];
        int[] value = new int[itemCount + 1];
        int read = 0;
        while (read < itemCount) {
            read++;
            volume[read] = in.nextInt();
            value[read] = in.nextInt();
        }

        int[] best = new int[capacity + 1];
        for (int item = 1; item <= itemCount; item++) {
            int vol = volume[item];
            int gain = value[item];
            // walk capacities upwards so best[cap - vol] may already include this item
            int cap = vol;
            while (cap <= capacity) {
                int withItem = best[cap - vol] + gain;
                if (withItem > best[cap]) {
                    best[cap] = withItem;
                }
                cap++;
            }
        }

        System.out.println(best[capacity]);
    }
}

数位成本和为目标值的最大数字(LC)

恰好等于容量的最大价值

class Solution {
    /**
     * Builds the largest number whose digit costs sum to exactly {@code target}.
     *
     * @param cost   {@code cost[d - 1]} is the cost of writing digit {@code d}
     * @param target total cost that must be spent exactly
     * @return the largest number as a string, or {@code "0"} when no combination hits the target
     */
    public String largestNumber(int[] cost, int target) {
        int digits = cost.length;

        // best[d][t]: largest numeral using digits 1..d with total cost exactly t;
        // null marks an unreachable cost. Filling the whole first row with "" instead
        // would answer "cost at most t" rather than "cost exactly t".
        String[][] best = new String[digits + 1][target + 1];
        best[0][0] = "";

        for (int d = 1; d <= digits; d++) {
            int price = cost[d - 1];
            for (int t = 0; t <= target; t++) {
                best[d][t] = best[d - 1][t];
                if (t < price) {
                    continue;
                }
                // the current row is consulted because a digit may be used repeatedly
                String rest = best[d][t - price];
                if (rest != null) {
                    best[d][t] = max(d + rest, best[d][t]);
                }
            }
        }
        String answer = best[digits][target];
        return answer != null ? answer : "0";
    }

    /**
     * Returns the numerically larger of two digit strings.
     *
     * @param s1 first numeral, must not be null
     * @param s2 second numeral, may be null (then {@code s1} wins)
     * @return the longer string, or for equal lengths the lexicographically larger one;
     *         {@code s1} on a tie
     */
    public String max (String s1, String s2) {
        if (s2 == null) {
            return s1;
        }
        int len1 = s1.length(), len2 = s2.length();
        if (len1 > len2) {
            return s1;
        }
        if (len1 < len2) {
            return s2;
        }
        for (int pos = 0; pos < len1; pos++) {
            char c1 = s1.charAt(pos);
            char c2 = s2.charAt(pos);
            if (c1 != c2) {
                return c1 > c2 ? s1 : s2;
            }
        }
        return s1;
    }
}

🍉 优化

class Solution {
    /**
     * Builds the largest number whose digit costs sum to exactly {@code target},
     * using a one-dimensional table.
     *
     * @param cost   {@code cost[d - 1]} is the cost of writing digit {@code d}
     * @param target total cost that must be spent exactly
     * @return the largest number as a string, or {@code "0"} when the target is unreachable
     */
    public String largestNumber(int[] cost, int target) {
        // best[t]: largest numeral with total cost exactly t; null = unreachable
        String[] best = new String[target + 1];
        best[0] = "";

        int digit = 0;
        for (int price : cost) {
            digit++;
            // ascending costs let the same digit be prepended several times
            for (int t = price; t <= target; t++) {
                String rest = best[t - price];
                if (rest != null) {
                    best[t] = max(digit + rest, best[t]);
                }
            }
        }
        String answer = best[target];
        return answer != null ? answer : "0";
    }

    /**
     * Returns the numerically larger of two digit strings.
     *
     * @param s1 first numeral, must not be null
     * @param s2 second numeral, may be null (then {@code s1} wins)
     * @return the longer string, or for equal lengths the lexicographically larger one;
     *         {@code s1} on a tie
     */
    public String max (String s1, String s2) {
        if (s2 == null) {
            return s1;
        }
        int len1 = s1.length(), len2 = s2.length();
        if (len1 != len2) {
            return len1 > len2 ? s1 : s2;
        }
        // for equal lengths the character-wise order decides
        return s1.compareTo(s2) >= 0 ? s1 : s2;
    }
}

多重背包问题

import java.util.*;

public class Main {
    /**
     * Solves the bounded (multiple) knapsack problem by trying every copy count.
     *
     * <p>Input format: item count and capacity, then one (volume, value, quantity)
     * triple per item. Prints the largest total value that fits into the capacity.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner in = new Scanner(System.in);
        int itemCount = in.nextInt();
        int capacity = in.nextInt();
        int[] volume = new int[itemCount + 1];
        int[] value = new int[itemCount + 1];
        int[] quantity = new int[itemCount + 1];
        for (int idx = 1; idx <= itemCount; idx++) {
            volume[idx] = in.nextInt();
            value[idx] = in.nextInt();
            quantity[idx] = in.nextInt();
        }

        // best[i][c]: largest value using the first i item kinds within capacity c
        int[][] best = new int[itemCount + 1][capacity + 1];
        for (int item = 1; item <= itemCount; item++) {
            int[] prev = best[item - 1];
            int[] cur = best[item];
            for (int cap = 1; cap <= capacity; cap++) {
                for (int copies = 0; copies <= quantity[item]; copies++) {
                    int used = copies * volume[item];
                    if (used > cap) {
                        continue;
                    }
                    int candidate = prev[cap - used] + copies * value[item];
                    cur[cap] = Math.max(cur[cap], candidate);
                }
            }
        }
        System.out.println(best[itemCount][capacity]);
    }
}

多重背包问题 II

import java.util.*;

public class Main {
    /** Initial capacity of the split-item arrays; they grow on demand beyond this. */
    static int N = 25000;

    /**
     * Solves the bounded (multiple) knapsack problem via binary splitting.
     *
     * <p>Input format: item count and capacity, then one (volume, value, quantity)
     * triple per item. Each quantity is split into bundles of 1, 2, 4, ... copies plus
     * a remainder, which turns the task into a 0/1 knapsack over the bundles.
     * Prints the largest total value that fits into the capacity.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int[] v = new int[N];
        int[] w = new int[N];
        int cnt = 0; // bundles are stored at indices 1..cnt

        // Preprocessing: split every quantity s into bundles of 1, 2, 4, ... copies
        // and a final remainder, so any count 0..s is a subset of the bundles.
        while (n-- > 0) {
            int a = sc.nextInt(), b = sc.nextInt(), s = sc.nextInt();
            for (int k = 1; s > 0; k *= 2) {
                int take = Math.min(k, s);
                if (cnt + 1 >= v.length) {
                    // the fixed-size arrays used to overflow here; grow them instead
                    int grown = Math.max(v.length * 2, cnt + 2);
                    v = Arrays.copyOf(v, grown);
                    w = Arrays.copyOf(w, grown);
                }
                ++ cnt;
                v[cnt] = a * take;
                w[cnt] = b * take;
                s -= take;
            }
        }

        // Plain 0/1 knapsack over the bundles.
        n = cnt;
        int[] f = new int[m + 1];
        for (int i = 1; i <= n; ++ i) {
            for (int j = m; j >= v[i]; -- j) {
                f[j] = Math.max(f[j], f[j - v[i]] + w[i]);
            }
        }

        System.out.println(f[m]);
    }
}

分组背包问题

import java.util.*;

public class Main{
    /**
     * Solves the grouped knapsack problem: at most one item may be taken from each group.
     *
     * <p>Input format: group count and capacity; then for every group its size followed
     * by one (volume, value) pair per item. Prints the largest total value that fits.
     * The tables are sized from the input, so neither the number of groups nor the
     * group size is capped.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();

        // v[i][k] / w[i][k]: volume / value of the k-th item (0-based) of group i (1-based)
        int[][] v = new int[Math.max(n, 0) + 1][];
        int[][] w = new int[Math.max(n, 0) + 1][];
        for (int i = 1; i <= n; ++ i) {
            int size = Math.max(sc.nextInt(), 0);
            v[i] = new int[size];
            w[i] = new int[size];
            for (int k = 0; k < size; ++ k) {
                v[i][k] = sc.nextInt();
                w[i][k] = sc.nextInt();
            }
        }

        int[] f = new int[m + 1];
        for (int i = 1; i <= n; ++ i) {
            // capacities go downwards so f[j - v] still belongs to the previous groups,
            // which guarantees that a group contributes at most one item
            for (int j = m; j >= 1; -- j) {
                for (int k = 0; k < v[i].length; ++ k) {
                    if (j >= v[i][k])
                        f[j] = Math.max(f[j], f[j - v[i][k]] + w[i][k]);
                }
            }
        }
        System.out.println(f[m]);
    }
}

线性dp

数字三角形

🧩 自顶向下

import java.util.*;

public class Main{
    /**
     * Number triangle, computed from the apex downwards.
     *
     * <p>Reads the row count and then the triangle row by row; prints the largest sum
     * of a path from the apex to the bottom row that steps to the cell directly below
     * or below-right at every level.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner in = new Scanner(System.in);
        int rows = in.nextInt();
        int[][] tri = new int[rows][rows];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c <= r; c++) {
                tri[r][c] = in.nextInt();
            }
        }

        // best[r][c]: largest path sum ending at cell (r, c)
        int[][] best = new int[rows][rows];
        best[0][0] = tri[0][0];
        for (int r = 1; r < rows; r++) {
            int[] above = best[r - 1];
            int[] cur = best[r];
            // the two edge cells have a single predecessor each
            cur[0] = tri[r][0] + above[0];
            cur[r] = tri[r][r] + above[r - 1];
            for (int c = 1; c < r; c++) {
                cur[c] = tri[r][c] + Math.max(above[c - 1], above[c]);
            }
        }

        int answer = best[rows - 1][0];
        for (int c = 1; c < rows; c++) {
            answer = Math.max(answer, best[rows - 1][c]);
        }
        System.out.println(answer);
    }
}

🧩 自底向上

import java.util.*;

public class Main{
    /** Minimum side length of the DP table; larger inputs get a larger table. */
    static int N = 510;

    /**
     * Number triangle, computed from the bottom row upwards.
     *
     * <p>Reads the row count and then the triangle row by row; prints the largest sum
     * of a path from the apex to the bottom row. The table is sized from the input, so
     * triangles with more rows than {@code N} allows no longer overflow it.
     *
     * @param args unused
     */
    public static void main (String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        // rows and columns are 1-based; index n + 1 must exist because row n's
        // successors f[n + 1][*] are read when n - 1 is processed with j + 1 == n
        int side = Math.max(N, n + 2);
        int[][] f = new int[side][side];
        for (int i = 1; i <= n; ++ i) {
            for (int j = 1; j <= i; ++ j) {
                f[i][j] = sc.nextInt();
            }
        }
        // fold every row into the one above it; f[i][j] becomes the best sum starting at (i, j)
        for (int i = n - 1; i >= 1; -- i) {
            for (int j = 1; j <= i; ++ j) {
                f[i][j] += Math.max(f[i + 1][j + 1], f[i + 1][j]);
            }
        }
        System.out.println(f[1][1]);
    }
}

最长上升子序列

import java.util.*;

public class Main{
    /**
     * Longest strictly increasing subsequence, quadratic version.
     *
     * <p>Reads the sequence length and the sequence; prints the length of its longest
     * strictly increasing subsequence.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] seq = new int[n];
        for (int idx = 0; idx < n; idx++) {
            seq[idx] = in.nextInt();
        }

        // len[e]: length of the longest increasing subsequence that ends at position e
        int[] len = new int[n];
        int answer = 0;
        for (int end = 0; end < n; end++) {
            int bestHere = 1;
            for (int prev = 0; prev < end; prev++) {
                if (seq[prev] < seq[end] && len[prev] + 1 > bestHere) {
                    bestHere = len[prev] + 1;
                }
            }
            len[end] = bestHere;
            if (bestHere > answer) {
                answer = bestHere;
            }
        }
        System.out.println(answer);
    }
}

最长上升子序列 II

import java.util.*;

public class Main{
    /**
     * Longest strictly increasing subsequence, binary-search version.
     *
     * <p>Reads the sequence length and the sequence; prints the length of its longest
     * strictly increasing subsequence.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int[] seq = new int[n];
        for (int idx = 0; idx < n; idx++) {
            seq[idx] = in.nextInt();
        }

        // tails[k]: last value of the increasing subsequence of length k recorded so far
        int[] tails = new int[n + 1];
        tails[0] = Integer.MIN_VALUE; // sentinel for the empty subsequence
        int longest = 0;
        for (int x : seq) {
            // find the largest length whose tail is still smaller than x
            int lo = 0, hi = longest;
            while (lo < hi) {
                int mid = (lo + hi + 1) >> 1;
                if (tails[mid] >= x) {
                    hi = mid - 1;
                } else {
                    lo = mid;
                }
            }
            // x extends that subsequence by one element
            tails[lo + 1] = x;
            if (lo + 1 > longest) {
                longest = lo + 1;
            }
        }
        System.out.println(longest);
    }
}

最长公共子序列

🧩 优化 & 贪心

import java.util.*;

public class Main{
    /**
     * Longest common subsequence of two strings.
     *
     * <p>Reads a line with the two lengths, then the two strings on separate lines;
     * prints the length of their longest common subsequence.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        String[] header = in.nextLine().split(" ");
        int n = Integer.parseInt(header[0]);
        int m = Integer.parseInt(header[1]);
        String a = in.nextLine();
        String b = in.nextLine();

        // lcs[i][j]: LCS length of the first i characters of a and the first j of b
        int[][] lcs = new int[n + 1][m + 1];
        int answer = 0;
        for (int i = 1; i <= n; i++) {
            char ca = a.charAt(i - 1);
            for (int j = 1; j <= m; j++) {
                int cell;
                if (ca != b.charAt(j - 1)) {
                    cell = Math.max(lcs[i - 1][j], lcs[i][j - 1]);
                } else {
                    cell = lcs[i - 1][j - 1] + 1;
                }
                lcs[i][j] = cell;
                if (cell > answer) {
                    answer = cell;
                }
            }
        }
        System.out.println(answer);
    }
}

最短编辑距离

import java.util.*;

public class Main{
    /**
     * Minimum edit distance between two strings.
     *
     * <p>Reads four lines: length of the first string, the first string, length of the
     * second string, the second string. Prints the minimum number of single-character
     * insertions, deletions and substitutions that turn the first into the second.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = Integer.parseInt(in.nextLine());
        String from = in.nextLine();
        int m = Integer.parseInt(in.nextLine());
        String to = in.nextLine();

        // dist[i][j]: edit distance between the first i chars of from and the first j of to
        int[][] dist = new int[n + 1][m + 1];
        for (int i = 0; i <= n; i++) {
            dist[i][0] = i; // delete everything
        }
        for (int j = 0; j <= m; j++) {
            dist[0][j] = j; // insert everything
        }

        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                int delete = dist[i - 1][j] + 1;
                int insert = dist[i][j - 1] + 1;
                int replace = dist[i - 1][j - 1];
                if (from.charAt(i - 1) != to.charAt(j - 1)) {
                    replace += 1;
                }
                dist[i][j] = Math.min(Math.min(delete, insert), replace);
            }
        }
        System.out.println(dist[n][m]);
    }
}

编辑距离

import java.util.*;

public class Main{
    /**
     * Answers edit-distance queries against a fixed list of strings.
     *
     * <p>Input format: a line with the number of strings {@code n} and the number of
     * queries {@code m}; {@code n} lines with one string each; {@code m} lines with a
     * query string and a limit. For every query, prints how many of the stored strings
     * are within the given edit distance of the query string.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String[] nums = sc.nextLine().split(" ");
        int n = Integer.parseInt(nums[0]), m = Integer.parseInt(nums[1]);
        String[] s = new String[n];
        for (int i = 0; i < n; ++ i) s[i] = sc.nextLine();

        while (m-- > 0) {
            String[] q = sc.nextLine().split(" ");
            // parse the query once instead of once per stored string
            String query = q[0];
            int limit = Integer.parseInt(q[1]);
            int ans = 0;
            for (String candidate : s) {
                // the length difference is a lower bound on the edit distance,
                // so clearly hopeless pairs skip the quadratic table
                if (Math.abs(candidate.length() - query.length()) > limit)
                    continue;
                if (eval_distance(candidate, query) <= limit)
                    ++ ans;
            }
            System.out.println(ans);
        }
    }

    /**
     * Computes the edit distance between two strings.
     *
     * @param s1 source string
     * @param s2 target string
     * @return minimum number of single-character insertions, deletions and
     *         substitutions that turn {@code s1} into {@code s2}
     */
    private static int eval_distance(String s1, String s2) {
        int n = s1.length(), m = s2.length();
        // f[i][j]: fewest operations turning the first i chars of s1 into the first j of s2
        int[][] f = new int[n + 1][m + 1];
        for (int i = 0; i <= n; ++ i) f[i][0] = i;
        for (int i = 0; i <= m; ++ i) f[0][i] = i;

        for (int i = 1; i <= n; ++ i) {
            for (int j = 1; j <= m; ++ j) {
                // delete from s1, or insert into s1
                f[i][j] = Math.min(f[i - 1][j] + 1, f[i][j - 1] + 1);
                // keep a matching character for free, otherwise substitute it
                int substitute = s1.charAt(i - 1) == s2.charAt(j - 1) ? 0 : 1;
                f[i][j] = Math.min(f[i][j], f[i - 1][j - 1] + substitute);
            }
        }
        return f[n][m];
    }
}

区间dp

石子合并

import java.util.*;

public class Main{
    /**
     * Stone merging: repeatedly merges two adjacent piles at a cost equal to their
     * combined weight and prints the minimum total cost of ending with a single pile.
     *
     * <p>Input format: the pile count followed by the pile weights.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        // prefix[i]: total weight of piles 1..i
        int[] prefix = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            prefix[i] = prefix[i - 1] + in.nextInt();
        }

        // cost[l][r]: minimum cost of merging piles l..r into one pile
        int[][] cost = new int[n + 1][n + 1];
        for (int span = 2; span <= n; span++) {
            for (int left = 1, right = span; right <= n; left++, right++) {
                int weight = prefix[right] - prefix[left - 1];
                int cheapest = Integer.MAX_VALUE;
                // try every position of the final merge
                for (int cut = left; cut < right; cut++) {
                    cheapest = Math.min(cheapest, cost[left][cut] + cost[cut + 1][right] + weight);
                }
                cost[left][right] = cheapest;
            }
        }
        System.out.println(cost[1][n]);
    }
}

(补)奇怪的打印机

class Solution {
    /**
     * Computes the minimum number of turns a "strange printer" needs to print {@code s}.
     * In one turn the printer writes a run of one repeated character over any
     * contiguous range, overwriting what was there.
     *
     * @param s the text to print, must not be null
     * @return the minimum number of turns; 0 for the empty string
     */
    public int strangePrinter(String s) {
        int n = s.length();
        if (n == 0) {
            // nothing to print; without this guard f[0][n - 1] is out of bounds
            return 0;
        }
        // f[i][j]: minimum number of turns needed to print s[i..j]
        int[][] f = new int[n][n];
        for (int i = n - 1; i >= 0; -- i) {
            f[i][i] = 1; // a single character takes one turn
            for (int j = i + 1; j < n; ++ j) {
                if (s.charAt(i) == s.charAt(j)) {
                    // the turn that prints s[i] can be stretched to cover s[j] for free
                    f[i][j] = f[i][j - 1];
                } else {
                    // otherwise the range is printed as two independent parts
                    int min = Integer.MAX_VALUE;
                    for (int k = i; k < j; ++ k) {
                        min = Math.min(min, f[i][k] + f[k + 1][j]);
                    }
                    f[i][j] = min;
                }
            }
        }
        return f[0][n - 1];
    }
}

计数型dp

整数划分

🧩 完全背包 & 朴素写法

  • 状态表示:
    • f [ i ] [ j ] 表示只从 1 ~ i 中选,且总和等于 j 的方案数
  • 状态转移方程:
    • f[i][j] = f[i - 1][j] + f[i][j - i](其中 f[i][j - i] 仅在 j ≥ i 时计入;边界条件 f[i][0] = 1)
import java.util.*;

public class Main{
    static int mod = (int) (1e9 + 7);

    /**
     * Integer partition counted as a complete knapsack with a two-dimensional table.
     *
     * <p>Reads {@code n} and prints, modulo {@code mod}, the number of ways to write
     * {@code n} as a sum of positive integers where the order of the parts is ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        // ways[i][j]: number of ways to reach sum j using only the parts 1..i
        int[][] ways = new int[n + 1][n + 1];
        for (int part = 1; part <= n; part++) {
            int[] prev = ways[part - 1];
            int[] cur = ways[part];
            cur[0] = 1; // the empty selection
            for (int sum = 1; sum <= n; sum++) {
                cur[sum] = prev[sum];
                if (sum >= part) {
                    // use the part at least once; the current row allows repeats
                    cur[sum] = (cur[sum] + cur[sum - part]) % mod;
                }
            }
        }
        System.out.println(ways[n][n]) ;
    }
}

🧩 完全背包 & 优化写法

import java.util.*;

public class Main{
    static int mod = (int) (1e9 + 7);

    /**
     * Integer partition counted as a complete knapsack with a one-dimensional table.
     *
     * <p>Reads {@code n} and prints, modulo {@code mod}, the number of ways to write
     * {@code n} as a sum of positive integers where the order of the parts is ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        // ways[s]: number of ways to reach sum s with the parts considered so far
        int[] ways = new int[n + 1];
        ways[0] = 1;
        int part = 1;
        while (part <= n) {
            // ascending sums let the same part be used any number of times
            for (int sum = part; sum <= n; sum++) {
                int combined = ways[sum] + ways[sum - part];
                ways[sum] = combined % mod;
            }
            part++;
        }
        System.out.println(ways[n]) ;
    }
}

🧩 另一种思维

  • 状态表示:
    • f [ i ] [ j ]表示总和为 i,总个数为 j 的方案数
  • 状态转移方程:
    • f[i][j] = f[i - 1][j - 1] + f[i - j][j](前一项:方案中最小的数是 1,去掉这个 1;后一项:方案中最小的数大于 1,把 j 个数各减去 1)
import java.util.*;

public class Main{
    static int mod = (int) (1e9 + 7);

    /**
     * Integer partition counted by total sum and number of parts.
     *
     * <p>Reads {@code n} and prints, modulo {@code mod}, the number of ways to write
     * {@code n} as a sum of positive integers where the order of the parts is ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        // ways[s][k]: number of partitions of s into exactly k parts
        int[][] ways = new int[n + 1][n + 1];
        ways[1][1] = 1;
        for (int sum = 2; sum <= n; sum++) {
            for (int parts = 1; parts <= sum; parts++) {
                // smallest part equals 1: drop that part
                int withOne = ways[sum - 1][parts - 1];
                // smallest part exceeds 1: subtract 1 from every part
                int withoutOne = ways[sum - parts][parts];
                ways[sum][parts] = (withOne + withoutOne) % mod;
            }
        }

        int total = 0;
        for (int parts = 1; parts <= n; parts++) {
            total = (total + ways[n][parts]) % mod;
        }
        System.out.println(total) ;
    }
}

状态压缩DP

蒙德里安的梦想

最短Hamilton路径

数位统计DP



Try and fail, but don't fail to try.