Monte Carlo Algorithms

2021-07-15  放开那个BUG

1. Preface

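I first heard of this algorithm in a job interview. The interviewer posed a problem asking for the area of a shaded region, similar to the figure below: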


[Figure: problem description]

The task was to find the area of the shaded region. It cast a bit of a shadow over me at the time, and I didn't manage to solve it.

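Later, someone told me this problem can be solved with the Monte Carlo method: throw a large number of uniformly random points into the square, count the fraction that land in the shaded region, and that fraction approximates the ratio of the shaded area to the area of the square.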

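First, recall that in a Cartesian coordinate system a circle with center (a, b) and radius r has the standard equation (x - a)^2 + (y - b)^2 = r^2.
A point in the shaded region satisfies both of the following:
(x - 1)^2 + (y - 1)^2 <= 1, because the small circle has radius 1 and center (1, 1), and the point must lie inside it; and x^2 + y^2 > 2^2, because shaded points must lie outside the circle of radius 2 centered at (0, 0).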

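Combining these two conditions: if m out of n uniformly random points in the 2 × 2 square satisfy both, then m / n ≈ area / (2 * 2), so area ≈ 4m / n. In code: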

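public double area() {
    int n = 1_000_000;
    int count = 0;

    for (int i = 0; i < n; i++) {
        // x is uniform in [0, 2)
        double x = Math.random() * 2;
        // y is uniform in [0, 2)
        double y = Math.random() * 2;

        // inside the unit circle centered at (1, 1) AND outside the radius-2 circle centered at (0, 0)
        if (Math.pow(x - 1, 2) + Math.pow(y - 1, 2) <= 1 && Math.pow(x, 2) + Math.pow(y, 2) > 4) {
            count++;
        }
    }

    // hit ratio times the area of the square (2 * 2 = 4)
    return 4.0 * count / n;
}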
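As a sanity check (my own addition, not part of the original problem), the shaded area can also be computed exactly, assuming the figure shows the square [0, 2] × [0, 2] as described above. Subtracting the two circle equations shows they intersect on the line x + y = 5/2, and the large circle's arc between the two intersection points lies inside the small circle. So the shaded region is the small circle's segment beyond that chord minus the large circle's segment beyond it, using the segment area r^2 arccos(d/r) - d√(r^2 - d^2) for a chord at distance d from the center:

```latex
\begin{aligned}
x^2 + y^2 = 4,\ (x-1)^2 + (y-1)^2 = 1 &\;\Rightarrow\; x + y = \tfrac{5}{2} \\
d_{\text{small}} = \frac{|1 + 1 - 5/2|}{\sqrt{2}} = \frac{\sqrt{2}}{4}, &\qquad
d_{\text{large}} = \frac{5/2}{\sqrt{2}} = \frac{5\sqrt{2}}{4} \\
S_{\text{small}} = \arccos\frac{\sqrt{2}}{4} - \frac{\sqrt{7}}{8}, &\qquad
S_{\text{large}} = 4\arccos\frac{5\sqrt{2}}{8} - \frac{5\sqrt{7}}{8} \\
\text{area} = S_{\text{small}} - S_{\text{large}}
&= \arccos\frac{\sqrt{2}}{4} - 4\arccos\frac{5\sqrt{2}}{8} + \frac{\sqrt{7}}{2} \approx 0.5855
\end{aligned}
```

A correct Monte Carlo estimate should therefore land near 0.586, while the raw hit ratio count / n (before multiplying by the square's area 4) hovers around 0.146.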

2. Monte Carlo Tree Search

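The Monte Carlo method rests on the law of large numbers: with enough samples the estimate gets arbitrarily close to the true value, but convergence is slow, because the error shrinks only in proportion to 1/√n. Besides, if all it could do were estimate π or integrals, I would find it rather naive.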
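To make the slow convergence concrete, here is a small sketch (the class and method names are my own) that estimates π with growing sample counts. The typical error falls roughly as 1/√n, so 100 times more points buys only about one more correct decimal digit.

```java
import java.util.Random;

public class PiConvergence {
    // Estimates pi by sampling n points uniformly in the unit square and
    // counting how many fall inside the quarter circle x^2 + y^2 <= 1.
    static double estimatePi(int n, Random rng) {
        int inside = 0;
        for (int i = 0; i < n; i++) {
            double x = rng.nextDouble();
            double y = rng.nextDouble();
            if (x * x + y * y <= 1.0) {
                inside++;
            }
        }
        // inside / n approximates the quarter-circle area pi / 4
        return 4.0 * inside / n;
    }

    public static void main(String[] args) {
        Random rng = new Random(42); // fixed seed so runs are reproducible
        for (int n = 100; n <= 1_000_000; n *= 100) {
            double estimate = estimatePi(n, rng);
            // the error typically shrinks about 10x for every 100x more samples
            System.out.printf("n = %,9d  estimate = %.5f  |error| = %.5f%n",
                    n, estimate, Math.abs(estimate - Math.PI));
        }
    }
}
```

The exact digits depend on the seed; what matters is the trend in the error column.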

The figure below shows Monte Carlo tree search. Put simply, the search works by guessing.


[Figure: Monte Carlo tree search]

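In the figure, each node represents a game position. A label A/B means the node has been visited B times and Black won A of those. For example, the root node is 12/21: 21 simulations in total, 12 of which Black won.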

We repeat the following process many times (tens of thousands of iterations or more):

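  • 1. The first step is Selection. Starting from the root, repeatedly move to the "most promising" child (the exact rule comes later) until reaching a node that has an unexpanded child, such as the 3/3 node in the figure. A node "has an unexpanded child" when its position still has follow-up moves that have never been tried.
  • 2. The second step is Expansion. Add a 0/0 child to this node, corresponding to the unexpanded child above, i.e. one move that has not been tried yet.
  • 3. The third step is Simulation. Starting from that untried move, play the game to the end with a fast rollout policy and record the result. The common view is that the rollout policy should be weak but very fast. A slow policy (for example, playing with AlphaGo's policy network) would play better and give more accurate results, but each rollout would take longer, so fewer simulations would fit into the same time budget, and the overall play would not necessarily be stronger; it could even be weaker. For the same reason we usually simulate only once per iteration: several simulations would be more accurate but slower.
  • 4. The fourth step is Backpropagation. Add the simulation result to the new node and all of its ancestors. For example, if the simulation in step 3 yields 0/1 (Black lost), add 0/1 to the new node and to every node above it.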
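The A/B bookkeeping and the backpropagation step can be sketched as follows. This is a minimal illustration with made-up names (StatNode, backpropagate), assuming every node stores its statistics from Black's point of view as in the figure; the path root 12/21 → 3/3 is hypothetical, with the figure's intermediate nodes omitted.

```java
public class StatsBackprop {
    // A tree node holding the A/B statistics from the figure:
    // blackWins = A, visits = B, both counted from Black's point of view.
    static final class StatNode {
        final StatNode parent;
        int blackWins;
        int visits;

        StatNode(StatNode parent, int blackWins, int visits) {
            this.parent = parent;
            this.blackWins = blackWins;
            this.visits = visits;
        }

        @Override
        public String toString() {
            return blackWins + "/" + visits;
        }
    }

    // Backpropagation: add one playout result to the new node and every ancestor.
    static void backpropagate(StatNode leaf, boolean blackWon) {
        for (StatNode node = leaf; node != null; node = node.parent) {
            node.visits++;
            if (blackWon) {
                node.blackWins++;
            }
        }
    }

    public static void main(String[] args) {
        StatNode root = new StatNode(null, 12, 21);    // root from the figure
        StatNode inner = new StatNode(root, 3, 3);     // hypothetical 3/3 node on the path
        StatNode expanded = new StatNode(inner, 0, 0); // Expansion: a fresh 0/0 child
        backpropagate(expanded, false);                // Simulation result: Black lost (0/1)
        System.out.println(root + " " + inner + " " + expanded); // 12/22 3/4 0/1
    }
}
```

The tic-tac-toe code later in this post scores nodes differently: its backPropogation adds WIN_SCORE only to nodes whose playerNo equals the winner, so each node is rated from the perspective of the player who made the move leading to it.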

The process above may still be confusing, so next I'll walk through Monte Carlo tree search on tic-tac-toe in plain language:

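  • 1. The first step is Selection. Starting from the root, pick the child with the largest UCT value (computed with the UCT formula; an unvisited node is treated as having the maximum UCT value, which keeps the search from always picking the node it tried first and spreads visits over all children), then pick that child's child with the largest UCT value, and so on until reaching a leaf node. In tic-tac-toe, the root has no children on the very first iteration, so the root itself is selected; in later iterations the root has children and the rule above applies.
  • 2. The second step is Expansion. Expand the selected node: in tic-tac-toe, count how many legal moves remain on the board, create one child node for each of them, and attach those children to the selected node.
  • 3. The third step is Simulation. Pick one of the selected node's children at random and, starting from the position after that move, let both players make random moves until the game ends (one side wins or it is a draw). Note that this playout is virtual: no new nodes are created under the child; it only serves to get a quick win/loss result.
  • 4. The fourth step is Backpropagation. Starting from the simulated child, propagate the result to its parent, then to that node's parent, and so on up to the root.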
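The selection rule can be sketched in a few lines. This is an illustrative version with hypothetical names and statistics, not the author's code: it uses the UCT formula wins / visits + c · √(ln N / visits), treats an unvisited child as +∞, and picks the child with the largest score. The constant c = √2 is the textbook choice for rewards in [0, 1]; the Java code below uses 1.41.

```java
public class UctSelection {
    // UCT score of a child: exploitation (win rate) + exploration bonus.
    // Unvisited children score +infinity, so every child is tried once first.
    static double uct(double wins, int visits, int parentVisits, double c) {
        if (visits == 0) {
            return Double.POSITIVE_INFINITY;
        }
        return wins / visits + c * Math.sqrt(Math.log(parentVisits) / visits);
    }

    // Selection step: index of the child with the largest UCT score.
    static int selectChild(double[] wins, int[] visits, int parentVisits, double c) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < visits.length; i++) {
            double score = uct(wins[i], visits[i], parentVisits, c);
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double c = Math.sqrt(2);
        // hypothetical statistics for three children of a node visited 21 times
        double[] wins = {7, 5, 0};
        int[] visits = {10, 8, 3};
        System.out.println(selectChild(wins, visits, 21, c)); // 1
        // an unvisited child always wins the selection
        System.out.println(selectChild(new double[] {7, 0}, new int[] {10, 0}, 10, c)); // 1
    }
}
```

In the first call, child 0 has the best win rate (0.7), but child 1 (0.625 over fewer visits) gets a larger exploration bonus and edges ahead, about 1.497 versus 1.480. That trade-off between exploitation and exploration is what UCT is for.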

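This process repeats. When does the algorithm stop? Whenever you tell it to (for example, after a fixed time or number of iterations), or when the board reaches a terminal state. Once the search ends, look at the root's children, pick the one with the highest visit count, and play that move.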
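The stopping rule and the final choice can be sketched like this (all names hypothetical). It mirrors what findNextMove and getChildWithMaxScore do in the code below: run iterations until a wall-clock budget runs out, then return the root child with the most visits, often called the "robust child".

```java
public class FinalMoveChoice {
    // Runs search iterations until the time budget is spent; returns how many ran.
    // 'iteration' stands in for one select/expand/simulate/backpropagate round.
    static int runWithBudget(long budgetMillis, Runnable iteration) {
        long deadline = System.currentTimeMillis() + budgetMillis;
        int count = 0;
        while (System.currentTimeMillis() < deadline) {
            iteration.run();
            count++;
        }
        return count;
    }

    // Final decision: the root child with the most visits (first one on ties).
    static int mostVisitedChild(int[] visits) {
        int best = 0;
        for (int i = 1; i < visits.length; i++) {
            if (visits[i] > visits[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int iterations = runWithBudget(50, () -> { });
        System.out.println("iterations in 50 ms: " + iterations);
        // hypothetical root statistics: move 1 might be 1/1, a perfect but
        // unreliable win rate; the most-visited move 0 is chosen instead
        int[] visits = {100, 1, 40};
        System.out.println("chosen move: " + mostVisitedChild(visits)); // 0
    }
}
```

Choosing by visit count rather than raw win rate avoids committing to a move that merely got lucky in very few playouts, because UCT keeps sending visits to moves that stay good.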

The flow above looks roughly like this:


[Figure: tic-tac-toe diagram]

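As for the code, I'll include it after all: the tic-tac-toe implementation below was written by a developer abroad, and I have simply copied it over.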
Node.java

package com.example.demo.easy.mengtekaluo;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class Node {
    State state;
    Node parent;
    List<Node> childArray;

    public Node() {
        this.state = new State();
        childArray = new ArrayList<>();
    }

    public Node(State state) {
        this.state = state;
        childArray = new ArrayList<>();
    }

    public Node(State state, Node parent, List<Node> childArray) {
        this.state = state;
        this.parent = parent;
        this.childArray = childArray;
    }

    public Node(Node node) {
        this.childArray = new ArrayList<>();
        this.state = new State(node.getState());
        if (node.getParent() != null)
            this.parent = node.getParent();
        List<Node> childArray = node.getChildArray();
        for (Node child : childArray) {
            this.childArray.add(new Node(child));
        }
    }

    public State getState() {
        return state;
    }

    public void setState(State state) {
        this.state = state;
    }

    public Node getParent() {
        return parent;
    }

    public void setParent(Node parent) {
        this.parent = parent;
    }

    public List<Node> getChildArray() {
        return childArray;
    }

    public void setChildArray(List<Node> childArray) {
        this.childArray = childArray;
    }

    public Node getRandomChildNode() {
        int noOfPossibleMoves = this.childArray.size();
        int selectRandom = (int) (Math.random() * noOfPossibleMoves);
        return this.childArray.get(selectRandom);
    }

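    // Despite the name, this returns the child with the highest visit count
    // (not the highest win score); that child becomes the move actually played
    public Node getChildWithMaxScore() {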
        return Collections.max(this.childArray, Comparator.comparing(c -> {
            return c.getState().getVisitCount();
        }));
    }

}

State.java

package com.example.demo.easy.mengtekaluo;

import java.util.ArrayList;
import java.util.List;

public class State {
    private Board board;
    private int playerNo;
    private int visitCount;
    private double winScore;

    public State() {
        board = new Board();
    }

    public State(State state) {
        this.board = new Board(state.getBoard());
        this.playerNo = state.getPlayerNo();
        this.visitCount = state.getVisitCount();
        this.winScore = state.getWinScore();
    }

    public State(Board board) {
        this.board = new Board(board);
    }

    Board getBoard() {
        return board;
    }

    void setBoard(Board board) {
        this.board = board;
    }

    int getPlayerNo() {
        return playerNo;
    }

    void setPlayerNo(int playerNo) {
        this.playerNo = playerNo;
    }

    int getOpponent() {
        return 3 - playerNo;
    }

    public int getVisitCount() {
        return visitCount;
    }

    public void setVisitCount(int visitCount) {
        this.visitCount = visitCount;
    }

    double getWinScore() {
        return winScore;
    }

    void setWinScore(double winScore) {
        this.winScore = winScore;
    }

    public List<State> getAllPossibleStates() {
        List<State> possibleStates = new ArrayList<>();
        List<Position> availablePositions = this.board.getEmptyPositions();
        availablePositions.forEach(p -> {
            State newState = new State(this.board);
            newState.setPlayerNo(3 - this.playerNo);
            newState.getBoard().performMove(newState.getPlayerNo(), p);
            possibleStates.add(newState);
        });
        return possibleStates;
    }

    void incrementVisit() {
        this.visitCount++;
    }

    void addScore(double score) {
        if (this.winScore != Integer.MIN_VALUE)
            this.winScore += score;
    }

    void randomPlay() {
        List<Position> availablePositions = this.board.getEmptyPositions();
        int totalPossibilities = availablePositions.size();
        int selectRandom = (int) (Math.random() * totalPossibilities);
        this.board.performMove(this.playerNo, availablePositions.get(selectRandom));
    }

    void togglePlayer() {
        this.playerNo = 3 - this.playerNo;
    }
}

Board.java

package com.example.demo.easy.mengtekaluo;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Board {
    int[][] boardValues;
    int totalMoves;

    public static final int DEFAULT_BOARD_SIZE = 3;

    public static final int IN_PROGRESS = -1;
    public static final int DRAW = 0;
    public static final int P1 = 1;
    public static final int P2 = 2;

    public Board() {
        boardValues = new int[DEFAULT_BOARD_SIZE][DEFAULT_BOARD_SIZE];
    }

    public Board(int boardSize) {
        boardValues = new int[boardSize][boardSize];
    }

    public Board(int[][] boardValues) {
        this.boardValues = boardValues;
    }

    public Board(int[][] boardValues, int totalMoves) {
        this.boardValues = boardValues;
        this.totalMoves = totalMoves;
    }

    public Board(Board board) {
        int boardLength = board.getBoardValues().length;
        this.boardValues = new int[boardLength][boardLength];
        int[][] boardValues = board.getBoardValues();
        int n = boardValues.length;
        for (int i = 0; i < n; i++) {
            int m = boardValues[i].length;
            for (int j = 0; j < m; j++) {
                this.boardValues[i][j] = boardValues[i][j];
            }
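        }
        // Also copy the move counter so the copy fully matches the original board
        this.totalMoves = board.totalMoves;
    }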

    public void performMove(int player, Position p) {
        this.totalMoves++;
        boardValues[p.getX()][p.getY()] = player;
    }

    public int[][] getBoardValues() {
        return boardValues;
    }

    public void setBoardValues(int[][] boardValues) {
        this.boardValues = boardValues;
    }

    public int checkStatus() {
        int boardSize = boardValues.length;
        int maxIndex = boardSize - 1;
        int[] diag1 = new int[boardSize];
        int[] diag2 = new int[boardSize];
        
        for (int i = 0; i < boardSize; i++) {
            int[] row = boardValues[i];
            int[] col = new int[boardSize];
            for (int j = 0; j < boardSize; j++) {
                col[j] = boardValues[j][i];
            }
            
            int checkRowForWin = checkForWin(row);
            if(checkRowForWin!=0)
                return checkRowForWin;
            
            int checkColForWin = checkForWin(col);
            if(checkColForWin!=0)
                return checkColForWin;
            
            diag1[i] = boardValues[i][i];
            diag2[i] = boardValues[maxIndex - i][i];
        }

        int checkDia1gForWin = checkForWin(diag1);
        if(checkDia1gForWin!=0)
            return checkDia1gForWin;
        
        int checkDiag2ForWin = checkForWin(diag2);
        if(checkDiag2ForWin!=0)
            return checkDiag2ForWin;
        
        if (getEmptyPositions().size() > 0)
            return IN_PROGRESS;
        else
            return DRAW;
    }

    private int checkForWin(int[] row) {
        boolean isEqual = true;
        int size = row.length;
        int previous = row[0];
        for (int i = 0; i < size; i++) {
            if (previous != row[i]) {
                isEqual = false;
                break;
            }
            previous = row[i];
        }
        if(isEqual)
            return previous;
        else
            return 0;
    }

    public void printBoard() {
        int size = this.boardValues.length;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                System.out.print(boardValues[i][j] + " ");
            }
            System.out.println();
        }
    }

    public List<Position> getEmptyPositions() {
        int size = this.boardValues.length;
        List<Position> emptyPositions = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                if (boardValues[i][j] == 0)
                    emptyPositions.add(new Position(i, j));
            }
        }
        return emptyPositions;
    }

    public void printStatus() {
        switch (this.checkStatus()) {
        case P1:
            System.out.println("Player 1 wins");
            break;
        case P2:
            System.out.println("Player 2 wins");
            break;
        case DRAW:
            System.out.println("Game Draw");
            break;
        case IN_PROGRESS:
            System.out.println("Game In Progress");
            break;
        }
    }
}

Position.java

package com.example.demo.easy.mengtekaluo;

public class Position {
    int x;
    int y;

    public Position() {
    }

    public Position(int x, int y) {
        this.x = x;
        this.y = y;
    }

    public int getX() {
        return x;
    }

    public void setX(int x) {
        this.x = x;
    }

    public int getY() {
        return y;
    }

    public void setY(int y) {
        this.y = y;
    }

}

Tree.java

package com.example.demo.easy.mengtekaluo;

public class Tree {
    Node root;

    public Tree() {
        root = new Node();
    }

    public Tree(Node root) {
        this.root = root;
    }

    public Node getRoot() {
        return root;
    }

    public void setRoot(Node root) {
        this.root = root;
    }

    public void addChild(Node parent, Node child) {
        parent.getChildArray().add(child);
    }

}

UCT.java

package com.example.demo.easy.mengtekaluo;

import java.util.Collections;
import java.util.Comparator;

public class UCT {

    public static double uctValue(int totalVisit, double nodeWinScore, int nodeVisit) {
        // An unvisited node gets the maximum value, so every child is tried at least once
        if (nodeVisit == 0) {
            return Integer.MAX_VALUE;
        }
        return (nodeWinScore / (double) nodeVisit) + 1.41 * Math.sqrt(Math.log(totalVisit) / (double) nodeVisit);
    }

    static Node findBestNodeWithUCT(Node node) {
        int parentVisit = node.getState().getVisitCount();
        return Collections.max(
          node.getChildArray(),
          Comparator.comparing(c -> uctValue(parentVisit, c.getState().getWinScore(), c.getState().getVisitCount())));
    }
}
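For reference, uctValue computes the following, where w_i is the child's winScore, n_i its visit count, N the parent's visit count, and the constant 1.41 ≈ √2. Plugging in the numbers from the unit test below (N = 600, w_i = 300, n_i = 20) reproduces its expected value of 15.79:

```latex
\mathrm{UCT}_i = \frac{w_i}{n_i} + 1.41\sqrt{\frac{\ln N}{n_i}},
\qquad
\frac{300}{20} + 1.41\sqrt{\frac{\ln 600}{20}} \approx 15 + 1.41 \times 0.5655 \approx 15.797
```

Note that because each win adds WIN_SCORE = 10, the term w_i / n_i ranges over [0, 10] rather than [0, 1], so compared with the textbook setting the exploration term carries relatively less weight and this code leans toward exploitation.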

MonteCarloTreeSearch.java

package com.example.demo.easy.mengtekaluo;

import java.util.List;

public class MonteCarloTreeSearch {

    private static final int WIN_SCORE = 10;
    private int level;
    private int opponent;

    public MonteCarloTreeSearch() {
        this.level = 3;
    }

    public int getLevel() {
        return level;
    }

    public void setLevel(int level) {
        this.level = level;
    }

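    // Search budget in seconds (1 s at level 1, 3 s at level 2, 5 s at level 3);
    // despite the name, findNextMove multiplies this value by 1000 to get milliseconds
    private int getMillisForCurrentLevel() {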
        return 2 * (this.level - 1) + 1;
    }

    public Board findNextMove(Board board, int playerNo) {
        long start = System.currentTimeMillis();
        long end = start + 1000 * getMillisForCurrentLevel();

        opponent = 3 - playerNo;
        Tree tree = new Tree();
        Node rootNode = tree.getRoot();
        rootNode.getState().setBoard(board);
        rootNode.getState().setPlayerNo(opponent);

        while (System.currentTimeMillis() < end) {
            // Phase 1 - Selection
            Node promisingNode = selectPromisingNode(rootNode);
            // Phase 2 - Expansion
            if (promisingNode.getState().getBoard().checkStatus() == Board.IN_PROGRESS)
                expandNode(promisingNode);

            // Phase 3 - Simulation
            Node nodeToExplore = promisingNode;
            if (promisingNode.getChildArray().size() > 0) {
                nodeToExplore = promisingNode.getRandomChildNode();
            }
            int playoutResult = simulateRandomPlayout(nodeToExplore);
            // Phase 4 - Update
            backPropogation(nodeToExplore, playoutResult);
        }

        Node winnerNode = rootNode.getChildWithMaxScore();
        tree.setRoot(winnerNode);
        return winnerNode.getState().getBoard();
    }

    private Node selectPromisingNode(Node rootNode) {
        Node node = rootNode;
        while (node.getChildArray().size() != 0) {
            node = UCT.findBestNodeWithUCT(node);
        }
        return node;
    }

    private void expandNode(Node node) {
        List<State> possibleStates = node.getState().getAllPossibleStates();
        possibleStates.forEach(state -> {
            Node newNode = new Node(state);
            newNode.setParent(node);
            newNode.getState().setPlayerNo(node.getState().getOpponent());
            node.getChildArray().add(newNode);
        });
    }

    private void backPropogation(Node nodeToExplore, int playerNo) {
        Node tempNode = nodeToExplore;
        while (tempNode != null) {
            tempNode.getState().incrementVisit();
            if (tempNode.getState().getPlayerNo() == playerNo)
                tempNode.getState().addScore(WIN_SCORE);
            tempNode = tempNode.getParent();
        }
    }

    private int simulateRandomPlayout(Node node) {
        Node tempNode = new Node(node);
        State tempState = tempNode.getState();
        int boardStatus = tempState.getBoard().checkStatus();

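        // If the opponent has already won in this position, the parent's move let the
        // opponent win immediately: mark the parent with Integer.MIN_VALUE so UCT avoids it
        // (addScore() leaves this sentinel untouched from then on)
        if (boardStatus == opponent) {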
            tempNode.getParent().getState().setWinScore(Integer.MIN_VALUE);
            return boardStatus;
        }
        while (boardStatus == Board.IN_PROGRESS) {
            tempState.togglePlayer();
            tempState.randomPlay();
            boardStatus = tempState.getBoard().checkStatus();
        }

        return boardStatus;
    }

}

MCTSUnitTest.java

package com.example.demo.easy.mengtekaluo;

import org.junit.Before;
import org.junit.Test;

import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class MCTSUnitTest {
    private Tree gameTree;
    private MonteCarloTreeSearch mcts;

    @Before
    public void initGameTree() {
        gameTree = new Tree();
        mcts = new MonteCarloTreeSearch();
    }

    @Test
    public void givenStats_whenGetUCTForNode_thenUCTMatchesWithManualData() {
        double uctValue = 15.79;
        assertEquals(UCT.uctValue(600, 300, 20), uctValue, 0.01);
    }

    @Test
    public void giveninitBoardState_whenGetAllPossibleStates_thenNonEmptyList() {
        State initState = gameTree.getRoot().getState();
        List<State> possibleStates = initState.getAllPossibleStates();
        assertTrue(possibleStates.size() > 0);
    }

    @Test
    public void givenEmptyBoard_whenPerformMove_thenLessAvailablePossitions() {
        Board board = new Board();
        int initAvailablePositions = board.getEmptyPositions().size();
        board.performMove(Board.P1, new Position(1, 1));
        int availablePositions = board.getEmptyPositions().size();
        assertTrue(initAvailablePositions > availablePositions);
    }

    @Test
    public void givenEmptyBoard_whenSimulateInterAIPlay_thenGameDraw() {
        Board board = new Board();

        int player = Board.P1;
        int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
        for (int i = 0; i < totalMoves; i++) {
            board = mcts.findNextMove(board, player);
            board.printBoard();
            if (board.checkStatus() != -1) {
                break;
            }
            player = 3 - player;
        }
        int winStatus = board.checkStatus();
        assertEquals(winStatus, Board.DRAW);
    }

    @Test
    public void givenEmptyBoard_whenLevel1VsLevel3_thenLevel3WinsOrDraw() {
        Board board = new Board();
        MonteCarloTreeSearch mcts1 = new MonteCarloTreeSearch();
        mcts1.setLevel(1);
        MonteCarloTreeSearch mcts3 = new MonteCarloTreeSearch();
        mcts3.setLevel(3);

        int player = Board.P1;
        int totalMoves = Board.DEFAULT_BOARD_SIZE * Board.DEFAULT_BOARD_SIZE;
        for (int i = 0; i < totalMoves; i++) {
            if (player == Board.P1)
                board = mcts3.findNextMove(board, player);
            else
                board = mcts1.findNextMove(board, player);

            board.printBoard();
            if (board.checkStatus() != -1) {
                break;
            }
            player = 3 - player;
        }
        int winStatus = board.checkStatus();
        assertTrue(winStatus == Board.DRAW || winStatus == Board.P1);
    }

}

3. Afterword

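I originally just wanted to find out how Monte Carlo tree search actually searches, and to learn more about DFS and similar techniques. I ended up covering a whole pile of algorithms, and I still don't feel I fully understand it.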

4. References

https://zhuanlan.zhihu.com/p/53948964
https://zhuanlan.zhihu.com/p/34990220
