Monte Carlo Algorithms
1. Preface
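I first heard of this algorithm in a job interview, when someone asked me to find the area of a shaded region, something like the figure below:
Figure: problem statement
The task was to compute the area of the shaded region. It left me a little shaken, and I could not work it out.
Later someone told me this problem can be solved with the Monte Carlo method: throw a large number of random points into the square, compute the fraction of them that land in the shaded region, and scale that fraction by the area of the square to get the result.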
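First, recall that in a coordinate system the standard equation of a circle is $(x - a)^2 + (y - b)^2 = r^2$, where $(a, b)$ is the center and $r$ is the radius.
A point in the shaded region satisfies:
$(x - 1)^2 + (y - 1)^2 \le 1$, because this circle has radius 1 and center (1, 1), so the point must lie inside it; and $x^2 + y^2 > 4$, because a shaded point must not lie inside the circle centered at (0, 0) (radius 2, matching the check in the code below).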
Combining the two conditions, if m of the n random points land in the shaded region, then m / n ≈ area / (2 × 2), so area ≈ 4m / n. In code:
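public double area() {
    // Number of random points to sample.
    int n = 1_000_000;
    int count = 0;
    for (int i = 0; i < n; i++) {
        // x is uniform in [0, 2)
        double x = Math.random() * 2;
        // y is uniform in [0, 2)
        double y = Math.random() * 2;
        // Inside the circle centered at (1, 1) with radius 1,
        // and outside the circle centered at (0, 0) with radius 2.
        boolean insideSmall = Math.pow(x - 1, 2) + Math.pow(y - 1, 2) <= 1;
        boolean outsideLarge = Math.pow(x, 2) + Math.pow(y, 2) > 4;
        if (insideSmall && outsideLarge) {
            count++;
        }
    }
    // The square has area 2 * 2 = 4, so scale the hit ratio by 4.
    return 4.0 * count / n;
}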
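2. Monte Carlo Search
The Monte Carlo method relies on the law of large numbers: with enough samples the estimate approaches the true value, although convergence is slow. That said, if all it could do were estimate π or evaluate integrals, I would find it far too naive.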
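To make the "approaches the true value, but slowly" point concrete, here is a minimal sketch that estimates π by sampling points in the unit square and counting how many fall inside the quarter circle. The class name `PiEstimator`, the method `estimatePi`, and the fixed seed are assumptions made for illustration; they are not part of the code in this article.

```java
import java.util.Random;

class PiEstimator {
    /**
     * Estimates pi by sampling points uniformly in the unit square and
     * counting the fraction that fall inside the quarter circle x^2 + y^2 <= 1.
     * The seed only makes the run repeatable; it is not required by the method.
     */
    static double estimatePi(int samples, long seed) {
        Random random = new Random(seed);
        int inside = 0;
        for (int i = 0; i < samples; i++) {
            double x = random.nextDouble();
            double y = random.nextDouble();
            if (x * x + y * y <= 1.0) {
                inside++;
            }
        }
        // The quarter circle has area pi / 4 and the square has area 1.
        return 4.0 * inside / samples;
    }

    public static void main(String[] args) {
        // Each step uses 100 times more samples than the previous one.
        for (int n = 100; n <= 1_000_000; n *= 100) {
            double estimate = estimatePi(n, 42L);
            System.out.printf("n=%d estimate=%.5f error=%.5f%n",
                    n, estimate, Math.abs(estimate - Math.PI));
        }
    }
}
```

The statistical error of this kind of estimate shrinks roughly like 1/√n, so 100 times more samples buys only about one extra correct digit. That is the slow convergence mentioned above.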
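The figure below shows Monte Carlo search. Put simply, the search runs entirely on guessing.
Figure: Monte Carlo search
Each node in the figure above represents a game position. The label A/B means the node has been visited B times and black won A of them. For example, the root starts at 12/21: 21 simulations in total, 12 of them won by black.
We repeat the following process many times (tens of thousands of times or more):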
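- 1. The first step is called Selection. Starting from the root, we walk down, each time choosing the "most promising child" (the exact rule comes later), until we reach a node that "has unexpanded children", such as the 3/3 node in the figure. A node "has unexpanded children" when its position has follow-up moves that have not been tried yet.
- 2. The second step is Expansion. We add a 0/0 child to this node, corresponding to the "unexpanded child" mentioned above, that is, a move that has not been tried yet.
- 3. The third step is Simulation. Starting from this untried move, we play to the end of the game with a fast rollout policy and obtain a win or a loss. The common view is that the rollout policy should be weak but very fast. If the policy is slow (for example, playing with AlphaGo's policy network), each playout is stronger and more accurate, but it takes longer, so fewer simulations fit into the same amount of time; overall playing strength does not necessarily improve and may even drop. This is also why we usually run only one simulation per expansion: more simulations would be more accurate but slower.
- 4. The fourth step is Backpropagation. The simulation result is added to all of the node's ancestors. For example, if the result of step 3 is 0/1 (black lost), then 0/1 is added to every ancestor of this node.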
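The backpropagation step can be sketched with a node that stores only the A/B statistics and a link to its parent. This is a minimal illustration, not the implementation listed later in this article: the class `StatNode` and its members are hypothetical, the path root → 3/3 → new node is a simplified stand-in for the figure, and the sketch assumes the new node's own counters are updated along with its ancestors (so that the 0/0 node records its first visit).

```java
class StatNode {
    final StatNode parent;
    int blackWins; // A in the A/B label
    int visits;    // B in the A/B label

    StatNode(StatNode parent, int blackWins, int visits) {
        this.parent = parent;
        this.blackWins = blackWins;
        this.visits = visits;
    }

    /** Adds one playout result to this node and to every ancestor up to the root. */
    void backpropagate(boolean blackWon) {
        for (StatNode node = this; node != null; node = node.parent) {
            node.visits++;
            if (blackWon) {
                node.blackWins++;
            }
        }
    }

    String label() {
        return blackWins + "/" + visits;
    }

    public static void main(String[] args) {
        StatNode root = new StatNode(null, 12, 21);
        StatNode selected = new StatNode(root, 3, 3);
        StatNode expanded = new StatNode(selected, 0, 0);
        // The playout from the new node ends in a loss for black, i.e. 0/1.
        expanded.backpropagate(false);
        System.out.println(expanded.label() + " " + selected.label() + " " + root.label());
    }
}
```

After a lost playout the three labels become 0/1, 3/4 and 12/22: every node on the path gains one visit, and none of them gains a win.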
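After reading the process above you may still feel a bit lost, so next I will walk through a game of tic-tac-toe with Monte Carlo tree search in plain language: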
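- 1. The first step is selection. Starting from the root, we pick the child with the largest UCT value (computed with the UCT formula; a node that has never been visited is treated as having the maximum UCT value, which keeps the search from always returning to whichever node it happened to pick first, so every child gets its turn), then pick the child with the largest UCT value under that child, and so on until we reach a leaf. In tic-tac-toe terms, selection starts at the root; in the first iteration the root has no children, so the root itself is selected (in later iterations, once the root has children, the rule above applies).
- 2. The second step is expansion. We expand the selected node: in tic-tac-toe, we enumerate every legal move on the remaining empty squares and attach one new child node to the selected node for each move.
- 3. The third step is simulation. We pick one of the selected node's children at random and, starting from that child's position, let both sides play random moves until the game ends (one side wins or it is a draw). Note that the game played out from this child is virtual: no new nodes are created beneath it, since the playout exists only to get a quick win/loss result.
- 4. The fourth step is backpropagation. The simulation result is propagated from that child node up through its parents until it reaches the root.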
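This process repeats in a loop. When does the algorithm stop? Whenever you tell it to (for example, when a time budget runs out), or when the board reaches a terminal state. Once the loop ends, we look at the root's children, choose the one with the highest visit count, and play that move. (The rule that an unvisited node is taken first belongs to the UCT selection step, not to this final choice; in the code below the final choice is simply the child with the largest visit count.)
The flow looks roughly like this:
Figure: tic-tac-toe search tree
As for the code, I will include it after all. It is a tic-tac-toe implementation written by a developer abroad; I merely copied it over.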
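The selection rule and the final move choice can be sketched in a few lines. A common form of the UCT score for a child is w / n + c · √(ln N / n), where w is the child's accumulated win score, n its visit count, N the parent's visit count, and c an exploration constant. The sketch below is an illustration under stated assumptions, not the article's code: the class `UctDemo`, its method names, and the sample numbers are made up, N is assumed to equal the sum of the children's visits, and c = √2 is one common choice (the UCT.java listing in this article uses 1.41).

```java
class UctDemo {
    /** UCT score of one child; an unvisited child gets the largest possible score. */
    static double uct(int parentVisits, double wins, int visits, double c) {
        if (visits == 0) {
            return Double.POSITIVE_INFINITY;
        }
        double exploitation = wins / visits;
        double exploration = c * Math.sqrt(Math.log(parentVisits) / visits);
        return exploitation + exploration;
    }

    /** Selection step: index of the child with the largest UCT score (first one on ties). */
    static int selectIndex(double[] wins, int[] visits, double c) {
        int parentVisits = 0;
        for (int v : visits) {
            parentVisits += v; // assumption: parent visits = sum of child visits
        }
        int best = 0;
        for (int i = 1; i < visits.length; i++) {
            if (uct(parentVisits, wins[i], visits[i], c) > uct(parentVisits, wins[best], visits[best], c)) {
                best = i;
            }
        }
        return best;
    }

    /** Final move choice: index of the child with the most visits. */
    static int mostVisited(int[] visits) {
        int best = 0;
        for (int i = 1; i < visits.length; i++) {
            if (visits[i] > visits[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double c = Math.sqrt(2);
        System.out.println(selectIndex(new double[]{7, 3, 0}, new int[]{10, 8, 0}, c));
        System.out.println(selectIndex(new double[]{7, 1}, new int[]{10, 2}, c));
        System.out.println(mostVisited(new int[]{10, 2}));
    }
}
```

With wins {7, 3, 0} and visits {10, 8, 0}, the unvisited third child is selected. With wins {7, 1} and visits {10, 2}, the rarely visited second child gets the higher UCT score because its exploration term dominates, while the final move choice would still be the first child, which has more visits. This is why selection during the search and the final decision can disagree.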
Node.java
package com.example.demo.easy.mengtekaluo;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
public class Node {
State state;
Node parent;
List<Node> childArray;
public Node() {
this.state = new State();
childArray = new ArrayList<>();
}
public Node(State state) {
this.state = state;
childArray = new ArrayList<>();
}
public Node(State state, Node parent, List<Node> childArray) {
this.state = state;
this.parent = parent;
this.childArray = childArray;
}
public Node(Node node) {
this.childArray = new ArrayList<>();
this.state = new State(node.getState());
if (node.getParent() != null)
this.parent = node.getParent();
List<Node> childArray = node.getChildArray();
for (Node child : childArray) {
this.childArray.add(new Node(child));
}
}
public State getState() {
return state;
}
public void setState(State state) {
this.state = state;
}
public Node getParent() {
return parent;
}
public void setParent(Node parent) {
this.parent = parent;
}
public List<Node> getChildArray() {
return childArray;
}
public void setChildArray(List<Node> childArray) {
this.childArray = childArray;
}
public Node getRandomChildNode() {
int noOfPossibleMoves = this.childArray.size();
int selectRandom = (int) (Math.random() * noOfPossibleMoves);
return this.childArray.get(selectRandom);
}
public Node getChildWithMaxScore() {
return Collections.max(this.childArray, Comparator.comparing(c -> {
return c.getState().getVisitCount();
}));
}
}
State.java
package com.example.demo.easy.mengtekaluo;
import java.util.ArrayList;
import java.util.List;
public class State {
private Board board;
private int playerNo;
private int visitCount;
private double winScore;
public State() {
board = new Board();
}
public State(State state) {
this.board = new Board(state.getBoard());
this.playerNo = state.getPlayerNo();
this.visitCount = state.getVisitCount();
this.winScore = state.getWinScore();
}
public State(Board board) {
this.board = new Board(board);
}
Board getBoard() {
return board;
}
void setBoard(Board board) {
this.board = board;
}
int getPlayerNo() {
return playerNo;
}
void setPlayerNo(int playerNo) {
this.playerNo = playerNo;
}
int getOpponent() {
return 3 - playerNo;
}
public int getVisitCount() {
return visitCount;
}
public void setVisitCount(int visitCount) {
this.visitCount = visitCount;
}
double getWinScore() {
return winScore;
}
void setWinScore(double winScore) {
this.winScore = winScore;
}
public List<State> getAllPossibleStates() {
List<State> possibleStates = new ArrayList<>();
List<Position> availablePositions = this.board.getEmptyPositions();
availablePositions.forEach(p -> {
State newState = new State(this.board);
newState.setPlayerNo(3 - this.playerNo);
newState.getBoard().performMove(newState.getPlayerNo(), p);
possibleStates.add(newState);
});
return possibleStates;
}
void incrementVisit() {
this.visitCount++;
}
void addScore(double score) {
if (this.winScore != Integer.MIN_VALUE)
this.winScore += score;
}
void randomPlay() {
List<Position> availablePositions = this.board.getEmptyPositions();
int totalPossibilities = availablePositions.size();
int selectRandom = (int) (Math.random() * totalPossibilities);
this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
}
void togglePlayer() {
this.playerNo = 3 - this.playerNo;
}
}
Board.java
package com.example.demo.easy.mengtekaluo;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class Board {
int[][] boardValues;
int totalMoves;
public static final int DEFAULT_BOARD_SIZE = 3;
public static final int IN_PROGRESS = -1;
public static final int DRAW = 0;
public static final int P1 = 1;
public static final int P2 = 2;
public Board() {
boardValues = new int[DEFAULT_BOARD_SIZE][DEFAULT_BOARD_SIZE];
}
public Board(int boardSize) {
boardValues = new int[boardSize][boardSize];
}
public Board(int[][] boardValues) {
this.boardValues = boardValues;
}
public Board(int[][] boardValues, int totalMoves) {
this.boardValues = boardValues;
this.totalMoves = totalMoves;
}
public Board(Board board) {
int boardLength = board.getBoardValues().length;
this.boardValues = new int[boardLength][boardLength];
int[][] boardValues = board.getBoardValues();
int n = boardValues.length;
for (int i = 0; i < n; i++) {
int m = boardValues[i].length;
for (int j = 0; j < m; j++) {
this.boardValues[i][j] = boardValues[i][j];
}
}
}
public void performMove(int player, Position p) {
this.totalMoves++;
boardValues[p.getX()][p.getY()] = player;
}
public int[][] getBoardValues() {
return boardValues;
}
public void setBoardValues(int[][] boardValues) {
this.boardValues = boardValues;
}
public int checkStatus() {
int boardSize = boardValues.length;
int maxIndex = boardSize - 1;
int[] diag1 = new int[boardSize];
int[] diag2 = new int[boardSize];
for (int i = 0; i < boardSize; i++) {
int[] row = boardValues[i];
int[] col = new int[boardSize];
for (int j = 0; j < boardSize; j++) {
col[j] = boardValues[j][i];
}
int checkRowForWin = checkForWin(row);
if(checkRowForWin!=0)
return checkRowForWin;
int checkColForWin = checkForWin(col);
if(checkColForWin!=0)
return checkColForWin;
diag1[i] = boardValues[i][i];
diag2[i] = boardValues[maxIndex - i][i];
}
int checkDiag1ForWin = checkForWin(diag1);
if(checkDiag1ForWin!=0)
return checkDiag1ForWin;
int checkDiag2ForWin = checkForWin(diag2);
if(checkDiag2ForWin!=0)
return checkDiag2ForWin;
if (getEmptyPositions().size() > 0)
return IN_PROGRESS;
else
return DRAW;
}
private int checkForWin(int[] row) {
boolean isEqual = true;
int size = row.length;
int previous = row[0];
for (int i = 0; i < size; i++) {
if (previous != row[i]) {
isEqual = false;
break;
}
previous = row[i];
}
if(isEqual)
return previous;
else
return 0;
}
public void printBoard() {
int size = this.boardValues.length;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
System.out.print(boardValues[i][j] + " ");
}
System.out.println();
}
}
public List<Position> getEmptyPositions() {
int size = this.boardValues.length;
List<Position> emptyPositions = new ArrayList<>();
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
if (boardValues[i][j] == 0)
emptyPositions.add(new Position(i, j));
}
}
return emptyPositions;
}
public void printStatus() {
switch (this.checkStatus()) {
case P1:
System.out.println("Player 1 wins");
break;
case P2:
System.out.println("Player 2 wins");
break;
case DRAW:
System.out.println("Game Draw");
break;
case IN_PROGRESS:
System.out.println("Game In Progress");
break;
}
}
}
Position.java
package com.example.demo.easy.mengtekaluo;
public class Position {
int x;
int y;
public Position() {
}
public Position(int x, int y) {
this.x = x;
this.y = y;
}
public int getX() {
return x;
}
public void setX(int x) {
this.x = x;
}
public int getY() {
return y;
}
public void setY(int y) {
this.y = y;
}
}
Tree.java
package com.example.demo.easy.mengtekaluo;
public class Tree {
Node root;
public Tree() {
root = new Node();
}
public Tree(Node root) {
this.root = root;
}
public Node getRoot() {
return root;
}
public void setRoot(Node root) {
this.root = root;
}
public void addChild(Node parent, Node child) {
parent.getChildArray().add(child);
}
}
UCT.java
package com.example.demo.easy.mengtekaluo;
import java.util.Collections;
import java.util.Comparator;
public class UCT {
public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
// An unvisited node gets the maximum value so that it is selected first
if (nodeVisit == 0) {
return Integer.MAX_VALUE;
}
return (nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
}
static Node findBestNodeWithUCT(Node node) {
int parentVisit = node.getState().getVisitCount();
return Collections.max(
node.getChildArray(),
Comparator.comparing(c -> uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount())));
}
}
MonteCarloTreeSearch.java
package com.example.demo.easy.mengtekaluo;
import java.util.List;
public class MonteCarloTreeSearch {
private static final int WIN_SCORE = 10;
private int level;
private int opponent;
public MonteCarloTreeSearch() {
this.level = 3;
}
public int getLevel() {
return level;
}
public void setLevel(int level) {
this.level = level;
}
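// Despite the name, findNextMove multiplies this value by 1000,
// so it acts as a per-move budget in seconds (level 3 gives 5 seconds).
private int getMillisForCurrentLevel() {
return 2 * (this.level - 1) + 1;
}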
public Board findNextMove(Board board, int playerNo) {
long start = System.currentTimeMillis();
long end = start + 1000 * getMillisForCurrentLevel();
opponent = 3 - playerNo;
Tree tree = new Tree();
Node rootNode = tree.getRoot();
rootNode.getState().setBoard(board);
rootNode.getState().setPlayerNo(opponent);
while (System.currentTimeMillis() < end) {
// Phase 1 - Selection
Node promisingNode = selectPromisingNode(rootNode);
// Phase 2 - Expansion
if (promisingNode.getState().getBoard().checkStatus() == Board.IN_PROGRESS)
expandNode(promisingNode);
// Phase 3 - Simulation
Node nodeToExplore = promisingNode;
if (promisingNode.getChildArray().size() > 0) {
nodeToExplore = promisingNode.getRandomChildNode();
}
int playoutResult = simulateRandomPlayout(nodeToExplore);
// Phase 4 - Update
backPropogation(nodeToExplore, playoutResult);
}
Node winnerNode = rootNode.getChildWithMaxScore();
tree.setRoot(winnerNode);
return winnerNode.getState().getBoard();
}
private Node selectPromisingNode(Node rootNode) {
Node node = rootNode;
while (node.getChildArray().size() != 0) {
node = UCT.findBestNodeWithUCT(node);
}
return node;
}
private void expandNode(Node node) {
List<State> possibleStates = node.getState().getAllPossibleStates();
possibleStates.forEach(state -> {
Node newNode = new Node(state);
newNode.setParent(node);
newNode.getState().setPlayerNo(node.getState().getOpponent());
node.getChildArray().add(newNode);
});
}
private void backPropogation(Node nodeToExplore, int playerNo) {
Node tempNode = nodeToExplore;
while (tempNode != null) {
tempNode.getState().incrementVisit();
if (tempNode.getState().getPlayerNo() == playerNo)
tempNode.getState().addScore(WIN_SCORE);
tempNode = tempNode.getParent();
}
}
private int simulateRandomPlayout(Node node) {
Node tempNode = new Node(node);
State tempState = tempNode.getState();
int boardStatus = tempState.getBoard().checkStatus();
if (boardStatus == opponent) {
tempNode.getParent().getState().setWinScore(Integer.MIN_VALUE);
return boardStatus;
}
while (boardStatus == Board.IN_PROGRESS) {
tempState.togglePlayer();
tempState.randomPlay();
boardStatus = tempState.getBoard().checkStatus();
}
return boardStatus;
}
}
MCTSUnitTest.java
package com.example.demo.easy.mengtekaluo;
import org.junit.Before;
import org.junit.Test;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class MCTSUnitTest {
private Tree gameTree;
private MonteCarloTreeSearch mcts;
@Before
public void initGameTree() {
gameTree = new Tree();
mcts = new MonteCarloTreeSearch();
}
@Test
public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
double uctValue = 15.79;
// JUnit's argument order is (expected, actual, delta)
assertEquals(uctValue, UCT.uctValue(600, 300, 20), 0.01);
}
@Test
public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
State initState = gameTree.getRoot().getState();
List<State> possibleStates = initState.getAllPossibleStates();
assertTrue(possibleStates.size() > 0);
}
@Test
public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
Board board = new Board();
int initAvailablePositions = board.getEmptyPositions().size();
board.performMove(Board.P1, new Position(1, 1));
int availablePositions = board.getEmptyPositions().size();
assertTrue(initAvailablePositions > availablePositions);
}
@Test
public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
Board board = new Board();
int player = Board.P1;
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
for (int i = 0; i < totalMoves; i++) {
board = mcts.findNextMove(board, player);
board.printBoard();
if (board.checkStatus() != Board.IN_PROGRESS) {
break;
}
player = 3 - player;
}
int winStatus = board.checkStatus();
assertEquals(Board.DRAW, winStatus);
}
@Test
public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
Board board = new Board();
MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
mcts1.setLevel(1);
MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
mcts3.setLevel(3);
int player = Board.P1;
int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
for (int i = 0; i < totalMoves; i++) {
if (player == Board.P1)
board = mcts3.findNextMove(board, player);
else
board = mcts1.findNextMove(board, player);
board.printBoard();
if (board.checkStatus() != Board.IN_PROGRESS) {
break;
}
player = 3 - player;
}
int winStatus = board.checkStatus();
assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
}
}
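3. Afterword
Originally I only wanted to understand how a Monte Carlo tree actually searches, and to get a better feel for things like DFS. I did not expect to end up describing a whole pile of algorithms, and it seems I still do not fully understand it.
4. References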
https://zhuanlan.zhihu.com/p/53948964
https://zhuanlan.zhihu.com/p/34990220