]> git.treefish.org Git - shutbox.git/blob - src/qtable.py
0ea19ecba7b0e88ae49b89e7cb46e83c3dd2ee25
[shutbox.git] / src / qtable.py
1 #!/usr/bin/env python3
2
3 import numpy as np
4 import random
5 import sys
6
7 from game import Game
8
9 states_dim = 147456 # 2^12 * 6^2
10 actions_dim = 637 # 12+1 * (6+1)^2
11 num_episodes = 1000
12
13 def find_state_qid(shutable, diced):
14     qid = 0
15     for rod in shutable:
16         qid += pow(2, rod-1)
17     for i in range(len(diced)):
18         qid += (diced[i]-1) * pow(6, i) * pow(2, 12)
19     return qid
20
21 def find_option_qid(option):
22     qid = 0
23     for i in range(len(option)):
24         qid += option[i] * pow(7, i) * pow(13, len(option)-1)
25     return qid
26
27 def select_option(opts, qs):
28     opt_qid_pairs = []
29     opt_qsum = 0.0
30     for opt in opts:
31         opt_qid = find_option_qid(opt)
32         opt_qid_pairs.append( [opt, opt_qid] )
33         opt_qsum += qs[opt_qid]
34     random.shuffle(opt_qid_pairs)
35     ran_pt = random.uniform(0.0, opt_qsum)
36     decision_pt = 0.0
37     for opt_qid_pair in opt_qid_pairs:
38         decision_pt += qs[ opt_qid_pair[1] ]
39         if ran_pt <= decision_pt:
40             return (opt_qid_pair[0], opt_qid_pair[1])
41     return (None, None)
42
43 Q = np.zeros([states_dim, actions_dim])
44
45 for i in range(num_episodes):
46     g = Game()
47     while not g.is_over():
48         g.dice()
49         state_qid = find_state_qid(g.get_shutable(), g.get_diced())
50         opt, opt_qid = select_option( g.get_options(), Q[state_qid, :] )
51         if opt:
52             g.shut(opt)
53     print( "%d: %d" % (i, g.get_score()) )