0%

回溯算法详解

刷题的时候碰到一个很常见的算法-回溯法,看了一些博客,感觉这个讲得通俗易懂。转过来,后面有些自己的补充。
  回溯法(探索与回溯法)是一种选优搜索法,又称为试探法,按选优条件向前搜索,以达到目标。但当探索到某一步时,发现原先选择并不优或达不到目标,就退回一步重新选择,这种走不通就退回再走的技术为回溯法,而满足回溯条件的某个状态的点称为“回溯点”。

首先我们来看一道题目:
Combinations:Given two integers n and k,return all possible combinations of k numbersout of 1 … n. For example, If n = 4 and k =2, a solution is:

[
[2,4],
[3,4],
[2,3],
[1,2],
[1,3],
[1,4],
]
(做一个白话版的描述,给你两个整数 n和k,从1-n中选择k个数字的组合。比如n=4,那么从1,2,3,4中选取两个数字的组合,包括图上所述的四种。)
然后我们看看题目给出的框架:

1
2
3
4
public class Solution {
public List<List<Integer>> combine(int n, int k) {
}
}

要求返回的类型是List<List> 也就是说将所有可能的组合list(由整数构成)放入另一个list(由list构成)中。
现在进行套路教学:要求返回List<List>,那我就给你一个List<List>,因此

  1. 定义一个全局List<List> result=new ArrayList<List>();
  2. 定义一个辅助的方法(函数)public void backtracking(int n,int k, Listlist){}
    n k 总是要有的吧,加上这两个参数,前面提到List 是数字的组合,也是需要的吧,这三个是必须的,没问题吧。(可以尝试性地写参数,最后不需要的删除)
  3. 接着就是我们的重头戏了,如何实现这个算法?对于n=4,k=2,1,2,3,4中选2个数字,我们可以做如下尝试,加入先选择1,那我们只需要再选择一个数字,注意这时候k=1了(此时只需要选择1个数字啦)。当然,我们也可以先选择2,3 或者4,通俗化一点,我们可以选择(1-n)的所有数字,这个是可以用一个循环来描述?每次选择一个加入我们的链表list中,下一次只要再选择k-1个数字。那什么时候结束呢?当然是k<0的时候啦,这时候都选完了。
    有了上面的分析,我们可以开始填写public void backtracking(int n,int k, List list){}中的内容。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    public void backtracking(int n,int k,int start,List<Integer> list){  
    if(k<0) return;
    else if(k==0){ //k==0表示已经找到了k个数字的组合,这时候加入全局result中
    result.add(new ArrayList(list));
    }else{
    for(int i=start;i<=n;i++){
    list.add(i);//尝试性的加入i
    //开始回溯啦,下一次要找的数字减少一个所以用k-1,i+1见后面分析
    backtracking(n,k-1,i+1,list);
    //(留白,有用=。=)
    }
    }
    }
    观察一下上述代码,我们加入了一个start变量,它是i的起点。为什么要加入它呢?比如我们第一次加入了1,下一次搜索的时候还能再搜索1了么?肯定不可以啊!我们必须从他的下一个数字开始,也就是2 、3或者4啦。所以start就是一个开始标记这个很重要啦!
    这时候我们在主方法中加入backtracking(n,k,1,list);调试后发现答案不对啊!为什么我的答案比他长那么多?

回溯回溯当然要退回再走啦,你不退回,当然又臭又长了!所以我们要在刚才代码注释留白处加上退回语句。仔细分析刚才的过程,我们每次找到了1,2这一对答案以后,下一次希望2退出然后让3进来,1 3就是我们要找的下一个组合。如果不回退,找到了2 ,3又进来,找到了3,4又进来,所以就出现了我们的错误答案。正确的做法就是加上:list.remove(list.size()-1);他的作用就是每次清除一个空位 让后续元素加入。寻找成功,最后一个元素要退位,寻找不到,方法不可行,那么我们回退,也要移除最后一个元素。
所以完整的程序如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public class Solution {  
List<List<Integer>> result=new ArrayList<List<Integer>>();
public List<List<Integer>> combine(int n, int k) {
List<Integer> list=new ArrayList<Integer>();
backtracking(n,k,1,list);
return result;
}
public void backtracking(int n,int k,int start,List<Integer>list){
if(k<0) return ;
else if(k==0){
result.add(new ArrayList(list));
}else{
for(int i=start;i<=n;i++){
list.add(i);
backtracking(n,k-1,i+1,list);
list.remove(list.size()-1);
}
}
}
}

是不是有点想法了?那么我们操刀一下。
Combination Sum
Given a set ofcandidate numbers (C) and a target number (T), findall unique combinations in C where thecandidate numbers sums toT.
The same repeated numbermay be chosen from C unlimited numberof times.
Note:
All numbers (including target) will be positive integers.
The solution set must not contain duplicate combinations.
For example,given candidate set [2, 3, 6, 7] and target 7,
A solution set is:
[
[7],
[2,2, 3]
]
(容我啰嗦地白话下,给你一个正数数组candidate[],一个目标值target,寻找里面所有的不重复组合,让其和等于target,给你[2,3,6,7] 2+2+3=7 ,7=7,所以可能组合为[2,2,3],[7])
按照前述的套路走一遍:

1
2
3
4
5
6
7
8
9
10
public class Solution {
List<List<Integer>> result=new ArrayList<List<Integer>>();
public List<List<Integer>> combinationSum(int[] candidates,int target) {
Arrays.sort(candidates);
List<Integer> list=new ArrayList<Integer>();
return result;
}
public void backtracking(int[] candidates,int target,int start,){
}
}
  1. 全局List<List> result先定义
  2. 回溯backtracking方法要定义,数组candidates 目标target 开头start 辅助链表List list都加上。
  3. 分析算法:以[2,3,6,7] 每次尝试加入数组任何一个值,用循环来描述,表示依次选定一个值
    1
    2
    3
    for(inti=start;i<candidates.length;i++){
    list.add(candidates[i]);
    }
    接下来回溯方法再调用。比如第一次选了2,下次还能再选2是吧,所以每次start都可以从当前i开始(ps:如果不允许重复,从i+1开始)。第一次选择2,下一次要凑的数就不是7了,而是7-2,也就是5,一般化就是remain=target-candidates[i],所以回溯方法为:
    backtracking(candidates,target-candidates[i],i,list);
    然后加上退回语句:list.remove(list.size()-1);
    那么什么时候找到的解符合要求呢?自然是remain(注意区分初始的target)=0了,表示之前的组合恰好能凑出target。如果remain<0 表示凑的数太大了,组合不可行,要回退。当remain>0 说明凑的还不够,继续凑。
    所以完整方法如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    public class Solution {  
    List<List<Integer>> result=newArrayList<List<Integer>>();
    public List<List<Integer>>combinationSum(int[] candidates, int target) {
    Arrays.sort(candidates);//所给数组可能无序,排序保证解按照非递减组合
    List<Integer> list=newArrayList<Integer>();
    backtracking(candidates,target,0,list);//给定target,start=0表示从数组第一个开始
    return result;//返回解的组合链表
    }
    public void backtracking(int[]candidates,int target,int start,List<Integer> list){

    if(target<0) return;//凑过头了
    else if(target==0){
    result.add(newArrayList<>(list));//正好凑出答案,开心地加入解的链表
    }else{
    for(int i=start;i<candidates.length;i++){//循环试探每个数
    list.add(candidates[i]);//尝试加入
    //下一次凑target-candidates[i],允许重复,还是从i开始
    backtracking(candidates,target-candidates[i],i,list);
    list.remove(list.size()-1);//回退
    }
    }

    }
    }
    是不是觉得还是有迹可循的?下一篇博客将部分回溯算法拿出来,供大家更好地发现其中的套路。
    This structure might apply to many other backtracking questions, but here I am just going to demonstrate Subsets, Permutations, and Combination Sum.

Subsets : https://leetcode.com/problems/subsets/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public List<List<Integer>> subsets(int[] nums) {
List<List<Integer>> list = new ArrayList<>();
Arrays.sort(nums);
backtrack(list, new ArrayList<>(), nums, 0);
return list;
}

private void backtrack(List<List<Integer>> list , List<Integer> tempList, int [] nums, int start){
list.add(new ArrayList<>(tempList));
for(int i = start; i < nums.length; i++){
tempList.add(nums[i]);
backtrack(list, tempList, nums, i + 1);
tempList.remove(tempList.size() - 1);
}
}

Subsets II (contains duplicates) : https://leetcode.com/problems/subsets-ii/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public List<List<Integer>> subsetsWithDup(int[] nums) {
List<List<Integer>> list = new ArrayList<>();
Arrays.sort(nums);
backtrack(list, new ArrayList<>(), nums, 0);
return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int [] nums, int start){
list.add(new ArrayList<>(tempList));
for(int i = start; i < nums.length; i++){
if(i > start && nums[i] == nums[i-1]) continue; // skip duplicates
tempList.add(nums[i]);
backtrack(list, tempList, nums, i + 1);
tempList.remove(tempList.size() - 1);
}
}

Permutations : https://leetcode.com/problems/permutations/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public List<List<Integer>> permute(int[] nums) {
List<List<Integer>> list = new ArrayList<>();
// Arrays.sort(nums); // not necessary
backtrack(list, new ArrayList<>(), nums);
return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int [] nums){
if(tempList.size() == nums.length){
list.add(new ArrayList<>(tempList));
} else{
for(int i = 0; i < nums.length; i++){
if(tempList.contains(nums[i])) continue; // element already exists, skip
tempList.add(nums[i]);
backtrack(list, tempList, nums);
tempList.remove(tempList.size() - 1);
}
}
}

Permutations II (contains duplicates) : https://leetcode.com/problems/permutations-ii/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public List<List<Integer>> permuteUnique(int[] nums) {
List<List<Integer>> list = new ArrayList<>();
Arrays.sort(nums);
backtrack(list, new ArrayList<>(), nums, new boolean[nums.length]);
return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int [] nums, boolean [] used){
if(tempList.size() == nums.length){
list.add(new ArrayList<>(tempList));
} else{
for(int i = 0; i < nums.length; i++){
if(used[i] || i > 0 && nums[i] == nums[i-1] && !used[i - 1]) continue;
used[i] = true;
tempList.add(nums[i]);
backtrack(list, tempList, nums, used);
used[i] = false;
tempList.remove(tempList.size() - 1);
}
}
}

Combination Sum : https://leetcode.com/problems/combination-sum/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public List<List<Integer>> combinationSum(int[] nums, int target) {
List<List<Integer>> list = new ArrayList<>();
Arrays.sort(nums);
backtrack(list, new ArrayList<>(), nums, target, 0);
return list;
}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int [] nums, int remain, int start){
if(remain < 0) return;
else if(remain == 0) list.add(new ArrayList<>(tempList));
else{
for(int i = start; i < nums.length; i++){
tempList.add(nums[i]);
backtrack(list, tempList, nums, remain - nums[i], i); // not i + 1 because we can reuse same elements
tempList.remove(tempList.size() - 1);
}
}
}

Combination Sum II (can’t reuse same element) : https://leetcode.com/problems/combination-sum-ii/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public List<List<Integer>> combinationSum2(int[] nums, int target) {
List<List<Integer>> list = new ArrayList<>();
Arrays.sort(nums);
backtrack(list, new ArrayList<>(), nums, target, 0);
return list;

}

private void backtrack(List<List<Integer>> list, List<Integer> tempList, int [] nums, int remain, int start){
if(remain < 0) return;
else if(remain == 0) list.add(new ArrayList<>(tempList));
else{
for(int i = start; i < nums.length; i++){
if(i > start && nums[i] == nums[i-1]) continue; // skip duplicates
tempList.add(nums[i]);
backtrack(list, tempList, nums, remain - nums[i], i + 1);
tempList.remove(tempList.size() - 1);
}
}
}

Palindrome Partitioning : https://leetcode.com/problems/palindrome-partitioning/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public List<List<String>> partition(String s) {
List<List<String>> list = new ArrayList<>();
backtrack(list, new ArrayList<>(), s, 0);
return list;
}

public void backtrack(List<List<String>> list, List<String> tempList, String s, int start){
if(start == s.length())
list.add(new ArrayList<>(tempList));
else{
for(int i = start; i < s.length(); i++){
if(isPalindrome(s, start, i)){
tempList.add(s.substring(start, i + 1));
backtrack(list, tempList, s, i + 1);
tempList.remove(tempList.size() - 1);
}
}
}
}

public boolean isPalindrome(String s, int low, int high){
while(low < high)
if(s.charAt(low++) != s.charAt(high--)) return false;
return true;
}