00001
00008 #pragma once
00009
00010 #include <cmath>
00011 #include <queue>
00012 #include <set>
00013
00014 #include "hash.h"
00015
00016 using std::map;
00017 using std::queue;
00018 using std::sqrt;
00019 using std::set;
00020 using std::make_pair;
00021
/**
 * Upper bound on the length of one playout.
 * The length is counted in units of 2 "moves", i.e. at most 2 times 4 steps
 * per unit.
 */
#define MAX_PLAYOUT_LENGTH 100

/** Depth limit of the UCT tree; Tree::history holds 2 * UCT_MAX_DEPTH nodes. */
#define UCT_MAX_DEPTH 30

/** Size of the children cache (presumably Node::cCache_ -- confirm in the implementation). */
#define CHILDREN_CACHE_SIZE 5

/** Threshold at which the children cache starts being used (presumably a visit count -- TODO confirm). */
#define CCACHE_START_THRESHOLD 50

/**
 * Playout length after which the position is evaluated instead of played on.
 * Expands to a call on `cfg`, so a `cfg` object must be visible at the point of use.
 */
#define EVAL_AFTER_LENGTH (cfg.playoutLen())

/** NOTE(review): looks like the "first play urgency" constant of UCT -- confirm against its use. */
#define FPU 0.9

/**
 * Value assigned to a won position for the given node type:
 * 2 for NODE_MAX, -1 for anything else.
 * The argument is parenthesised so that compound expressions
 * (e.g. `c ? NODE_MAX : NODE_MIN`) are compared as a whole.
 * NOTE(review): the asymmetric 2 / -1 pair is kept exactly as found -- confirm it is intended.
 */
#define NODE_VICTORY(node_type) ((node_type) == NODE_MAX ? 2 : -1)

/**
 * Maps a playout winner to a sample value: 1 for GOLD, -1 otherwise.
 * The argument is parenthesised for the same precedence reason as above.
 */
#define WINNER_TO_VALUE(winner) ((winner) == GOLD ? 1 : -1)
00031
/** Result codes returned by SimplePlayout::doPlayout(). */
enum playoutStatus_e {
    PLAYOUT_OK,        // enumerator value 0
    PLAYOUT_TOO_LONG,  // enumerator value 1
    PLAYOUT_EVAL       // enumerator value 2
};

// Forward declarations; both are only used through pointers / friendship below.
class Engine;
class Uct;

/** Kind of a tree node: maximising or minimising. */
enum nodeType_e {
    NODE_MAX,
    NODE_MIN
};

/**
 * Node type belonging to a player: NODE_MAX for GOLD, NODE_MIN otherwise.
 * The argument is parenthesised so that compound expressions (e.g. `p & 1`)
 * are evaluated before the comparison with GOLD.
 */
#define PLAYER_TO_NODE_TYPE(player) ((player) == GOLD ? NODE_MAX : NODE_MIN)
00046
00053 class SimplePlayout
00054 {
00055 public:
00059 SimplePlayout(Board*, uint maxPlayoutLength, uint evalAfterLength);
00060
00068 playoutStatus_e doPlayout();
00069
00073 uint getPlayoutLength();
00074
00075 virtual ~SimplePlayout(){};
00076
00077 protected:
00083 virtual void playOne();
00084
00088 bool hasWinner();
00089
00090 SimplePlayout();
00091
00093 Board* board_;
00095 uint playoutLength_;
00097 uint maxPlayoutLength_;
00099 uint evalAfterLength_;
00100
00101 };
00102
00110 class AdvisorPlayout : public SimplePlayout
00111 {
00112 public:
00116 AdvisorPlayout(Board*, uint maxPlayoutLength, uint evalAfterLength,
00117 MoveAdvisor* advisor);
00118
00122 void playOne();
00123
00124 private:
00126 MoveAdvisor * advisor_;
00127 };
00128
00136 class TWstep
00137 {
00138 public:
00139 TWstep(Step, float, int);
00140 Step step;
00141 float value;
00142 int visits;
00143 } ;
00144
00145 typedef map<Step, TWstep> TWstepsMap;
00146
00153 class TWsteps: public TWstepsMap
00154 {
00155 public:
00156 TWstep& operator[](const Step& step);
00157 private:
00158 };
00159
00160 typedef set<Node*> NodeSet;
00161
00162
00169 class TTitem
00170 {
00171 public:
00172 TTitem(NodeList*);
00173
00174 NodeList* getNodes() const;
00175 private:
00176 TTitem();
00177
00178
00179 int visits_;
00180
00181 float value_;
00183 NodeList* nodes_;
00184
00185 friend class Node;
00186 };
00187
00191 class Node
00192 {
00193 public:
00194 Node();
00195
00199 Node(TWstep*, float heur=0);
00200
00204 Node* findUctChild(Node * realFather);
00205
00211 Node* findRandomChild() const;
00212
00216 Node* findMostExploredChild() const;
00217
00221 float exploreFormula(float) const;
00222
00228 void cCacheInit();
00229
00234 void cCacheUpdate(float exploreCoeff);
00235
00239 void uctOneChild(Node* act, Node* & best, float & bestUrgency, float exploreCoeff) const;
00240
00244 float ucb(float exploreCoeff) const;
00245
00249 float ucbTuned(float exploreCoeff) const;
00250
00254 void addChild(Node* child);
00255
00259 void reverseChildren();
00260
00264 void delChildrenRec();
00265
00275 void connectToMaster(const bool lock=true);
00276
00283 void connectChildrenToMaster();
00284
00290 void syncMaster();
00291
00297 void recSyncMaster();
00298
00304 void updateTTbrothers(float sample=0, int size=0);
00305
00311 void update(float);
00312
00318 void updateTWstep(float);
00319
00328 bool isMature() const;
00329
00330 bool hasChildren() const;
00331
00332
00333
00334 Node* getFather() const;
00335 void setFather(Node*);
00336 Node* getFirstChild() const;
00337 void setFirstChild(Node *);
00338 Node* getSibling() const;
00339 void setSibling(Node*);
00340 TTitem* getTTitem() const;
00341 void setTTitem(TTitem * node);
00342 Step getStep() const;
00343 TWstep* getTWstep() const;
00344 player_t getPlayer() const;
00345 int getVisits() const;
00346 void setVisits(int visits);
00347 float getValue() const;
00348 void setValue(float value);
00349 void setMaster(Node* master);
00350 Node* getMaster();
00351 void lock();
00352 void unlock();
00353 nodeType_e getNodeType() const;
00354
00358 int getDepth() const;
00359
00363 int getLocalDepth() const;
00364
00368 int getLevel() const;
00369
00373 int getDepthIdentifier() const;
00374
00378 string toString() const;
00379
00383 string recToString(int) const;
00384
00385 private:
00387 float value_;
00389 float heur_;
00391 float squareSum_;
00393 int visits_;
00395 TWstep* twStep_;
00396
00398 float masterValue_;
00399
00401 int masterVisits_;
00402
00404 TTitem* ttItem_;
00405
00406 Node* sibling_;
00407 Node* firstChild_;
00408 Node* father_;
00410 int cCacheLastUpdate_;
00412 Node** cCache_;
00414 Node* master_;
00415
00416
00417 pthread_mutex_t mutex;
00418 };
00419
00423 class Tree
00424 {
00425
00426 public:
00432 Tree(Node * root);
00433
00439 Tree(player_t firstPlayer);
00440
00446 ~Tree();
00447
00451 void syncMaster();
00452
00462 void expandNode(Node* node, const StepArray& steps, uint len, const HeurArray* heurs=NULL);
00463
00464
00465
00466
00467
00468
00469
00470 void expandNodeLimited(Node* node, const Move& move);
00471
00478 void uctDescend();
00479
00485 void randomDescend();
00486
00492 void firstChildDescend();
00493
00501 Node* findBestMoveNode(Node* subTreeRoot);
00502
00509 Move findBestMove(Node* bestMoveNode);
00510
00514 void updateHistory(float);
00515
00521 void historyReset();
00522
00526 Node* root();
00527
00531 Node* actNode();
00532
00536 int getNodesNum();
00537
00541 int getNodesPrunedNum();
00542
00546 int getNodesExpandedNum();
00547
00551 string toString();
00552
00559 string pathToActToString(bool onlyLastMove = false);
00560
00570 void updateTT(Node* father, const Board* board);
00571
00572
00573 private:
00574 friend class Uct;
00580 Tree();
00581
00585 void init();
00586
00595 static int calcNodeLevel(Node* father, const Step& step);
00596
00600 Node* history[2 * UCT_MAX_DEPTH];
00602 uint historyTop;
00604 int nodesExpandedNum_;
00606 int nodesNum_;
00608 int nodesPrunedNum_;
00613 TWsteps twSteps_;
00615 TT* tt_;
00616
00617 };
00618
00622 class Uct
00623 {
00624 public:
00625
00629 Uct(const Board* board);
00630
00637 Uct(const Board* board, const Uct* masterUct);
00638
00639 ~Uct();
00640
00647 void updateStatistics(Uct* ucts[], int uctsNum);
00648
00654 void searchTree(const Board*, const Engine*);
00655
00659 void refineResults(const Board* board);
00660
00667 void doPlayout(const Board* board);
00668
00672 string getStats(float seconds) const;
00673
00677 string getAdditionalInfo() const;
00678
00682 string getBestMoveRepr() const;
00683
00687 int getBestMoveVisits() const;
00688
00692 float getBestMoveValue() const;
00693
00699 float getWinRatio() const;
00700
00704 Tree* getTree() const;
00705
00709 int getPlayoutsNum() const;
00710
00711 private:
00715 Uct();
00716
00720 void init(const Board* board);
00721
00728 double decidePlayoutWinner(const Board*) const;
00729
00733 void fill_advisor(const Board * playBoard);
00734
00736 Tree* tree_;
00738 Eval* eval_;
00740 Node* bestMoveNode_;
00742 string bestMoveRepr_;
00744 int playouts_;
00746 int uctDescends_;
00747
00748 MoveAdvisor * advisor_;
00749 };
00750
00751