]> git.treefish.org Git - shutbox.git/blob - src/1on1.py
refactoring
[shutbox.git] / src / 1on1.py
1 #!/usr/bin/env python3
2
3 import numpy as np
4 import random
5 import sys
6
7 from game import Game
8
# Hyper-parameters of the one-step Q-learning update in the training loop.
learning_rate = 0.1
discount_factor = 1.0

# Q-table rows: a 10-bit bitmask of shutable rods times 6 * 6 ordered dice
# faces (the encoding produced by find_state_qid).
states_dim = 2 ** 10 * 6 ** 2
# Q-table columns: (10 + 1) * (6 + 1) ** 2 option codes.
# NOTE(review): find_option_qid is not obviously bounded by this size (a
# two-element option (a, b) maps to 11*a + 77*b) -- confirm against the option
# format returned by Game.get_options().
actions_dim = (10 + 1) * (6 + 1) ** 2
# Number of training episodes (ten billion).
num_episodes = 10 ** 10
15
def find_state_qid(shutable, diced):
    """Encode the shutable rods and the dice faces as one Q-table row index.

    Each rod ``r`` in ``shutable`` contributes ``2 ** (r - 1)`` (a bitmask,
    assuming rods are numbered 1..10). The dice faces are read as base-6
    digits (face 1 -> digit 0, first die least significant), and that number
    is scaled by ``2 ** 10`` so it sits above the rod bitmask.
    """
    rod_mask = sum(2 ** (rod - 1) for rod in shutable)
    dice_part = sum((face - 1) * 6 ** pos * 2 ** 10
                    for pos, face in enumerate(diced))
    return rod_mask + dice_part
23
def find_option_qid(option):
    """Encode an option (a sequence of ints) as one Q-table column index.

    Element ``k`` is weighted by ``7 ** k``, and every term is additionally
    scaled by ``11 ** (len(option) - 1)``. An empty option maps to 0.

    NOTE(review): for options longer than one element this is not a plain
    positional encoding and may exceed the Q-table width -- confirm against
    the option format produced by Game.get_options().
    """
    scale = 11 ** (len(option) - 1)
    return sum(value * 7 ** pos * scale for pos, value in enumerate(option))
29
def select_option(opts, qs):
    """Pick one option at random, with probability proportional to its Q value.

    ``qs`` is a row of the Q table, indexed by ``find_option_qid(option)``.
    Returns an ``(option, option_qid)`` tuple. If all weights are zero, the
    first option is returned; if ``opts`` is empty, the function returns None.
    Weights are presumably non-negative -- negative values would skew the
    roulette wheel.
    """
    keyed = [(opt, find_option_qid(opt)) for opt in opts]

    total = 0.0
    for _, qid in keyed:
        total += qs[qid]

    # Roulette wheel: walk the cumulative weights until passing the draw.
    threshold = random.uniform(0.0, total)
    cumulative = 0.0
    for opt, qid in keyed:
        cumulative += qs[qid]
        if threshold <= cumulative:
            return opt, qid
43
# Q table: one row per state code (find_state_qid), one column per option code
# (find_option_qid). Every entry starts at 1.0; select_option uses the row
# values directly as selection weights.
Q = np.ones([states_dim, actions_dim])

# Episode tallies, both seeded with 1.0 so the printed ratio is always defined:
#   stats[0] -- episodes that end with the random opponent to move,
#   stats[1] -- episodes that end with the learning player to move (treated as
#               a loss below: the final state's Q row is zeroed).
stats = [1.0, 1.0]

for _ in range(num_episodes):
    g = Game()
    g.dice()
    state_qid = find_state_qid(g.get_shutable(), g.get_diced())
    # Even turn numbers belong to the learning player, odd ones to the random
    # opponent; the coin flip decides who moves first.
    num_turn = random.randint(0, 1)
    while not g.is_over():
        options = g.get_options()
        if len(options) == 0:
            # Previously there was no branch for this case: with nothing shut
            # and nothing re-diced, the loop repeated forever. End the episode
            # instead; the player to move is then scored as at a normal end.
            # NOTE(review): presumably Game.is_over() already covers this and
            # the branch never fires -- confirm against Game.
            break
        if num_turn % 2 == 0:
            # Learning player: sample an option weighted by its Q values, play
            # it, roll, then apply a one-step Q-learning update.
            opt, opt_qid = select_option(options, Q[state_qid, :])
            g.shut(opt)
            g.dice()
            reward = 0.0  # no per-move reward; see the terminal zeroing below
            new_state_qid = find_state_qid(g.get_shutable(), g.get_diced())
            # NOTE(review): the state code does not record whose turn it is,
            # and new_state_qid is a state where the *opponent* moves next.
            # np.max also spans every column of that row, including options
            # never selected there, which keep their initial 1.0 -- so the
            # target is 1.0 whenever any such column is still untouched.
            Q[state_qid, opt_qid] += \
                learning_rate * (reward
                                 + discount_factor * np.max(Q[new_state_qid, :])
                                 - Q[state_qid, opt_qid])
            state_qid = new_state_qid
        else:
            # Random opponent: uniform choice among the current options.
            choice = random.randint(0, len(options) - 1)
            g.shut(options[choice])
            g.dice()
            state_qid = find_state_qid(g.get_shutable(), g.get_diced())
        num_turn += 1
    if num_turn % 2 == 0:
        # Episode ended with the learning player to move: count it as a loss
        # and zero the final state's row so updates bootstrapping from it see 0.
        stats[1] += 1.0
        Q[state_qid, :] = 0.0
    else:
        stats[0] += 1.0
    # Running ratio of the two tallies, printed after every episode.
    print(stats[0]/stats[1])