git.treefish.org Git - shutbox.git/blobdiff - src/qtable.py
refactoring
[shutbox.git] / src / qtable.py
index 0ea19ecba7b0e88ae49b89e7cb46e83c3dd2ee25..29a3081235bfa8675657cdc6d0b866158ee88f7a 100755 (executable)
@@ -6,22 +6,25 @@ import sys
 
 from game import Game
 
-states_dim = 147456 # 2^12 * 6^2
-actions_dim = 637 # 12+1 * (6+1)^2
-num_episodes = 1000
+learning_rate = 0.001
+discount_factor = 1.0
+
+states_dim = 36864 # 2^10 * 6^2
+actions_dim = 539 # 10+1 * (6+1)^2
+num_episodes = 10000000000
 
 def find_state_qid(shutable, diced):
     qid = 0
     for rod in shutable:
         qid += pow(2, rod-1)
     for i in range(len(diced)):
-        qid += (diced[i]-1) * pow(6, i) * pow(2, 12)
+        qid += (diced[i]-1) * pow(6, i) * pow(2, 10)
     return qid
 
 def find_option_qid(option):
     qid = 0
     for i in range(len(option)):
-        qid += option[i] * pow(7, i) * pow(13, len(option)-1)
+        qid += option[i] * pow(7, i) * pow(11, len(option)-1)
     return qid
 
 def select_option(opts, qs):
@@ -31,23 +34,38 @@ def select_option(opts, qs):
         opt_qid = find_option_qid(opt)
         opt_qid_pairs.append( [opt, opt_qid] )
         opt_qsum += qs[opt_qid]
-    random.shuffle(opt_qid_pairs)
     ran_pt = random.uniform(0.0, opt_qsum)
     decision_pt = 0.0
     for opt_qid_pair in opt_qid_pairs:
         decision_pt += qs[ opt_qid_pair[1] ]
         if ran_pt <= decision_pt:
             return (opt_qid_pair[0], opt_qid_pair[1])
-    return (None, None)
 
-Q = np.zeros([states_dim, actions_dim])
+Q = np.ones([states_dim, actions_dim])
+
+running_score = [0.0, 0.0]
 
 for i in range(num_episodes):
     g = Game()
+    g.dice()
+    state_qid = find_state_qid(g.get_shutable(), g.get_diced())
     while not g.is_over():
-        g.dice()
-        state_qid = find_state_qid(g.get_shutable(), g.get_diced())
-        opt, opt_qid = select_option( g.get_options(), Q[state_qid, :] )
-        if opt:
+        options = g.get_options()
+        if len(options) > 0:
+            opt, opt_qid = select_option( options, Q[state_qid, :] )
+            old_score = g.get_score()
             g.shut(opt)
-    print( "%d: %d" % (i, g.get_score()) )
+            g.dice()
+            reward = g.get_score() - old_score
+            new_state_qid = find_state_qid(g.get_shutable(), g.get_diced())
+            Q[state_qid, opt_qid] += \
+                learning_rate * (reward
+                                 + discount_factor * np.max(Q[new_state_qid, :])
+                                 - Q[state_qid, opt_qid])
+            state_qid = new_state_qid
+    Q[state_qid, :] = 0.0
+    running_score[0] *= 0.99999999
+    running_score[0] += g.get_score()
+    running_score[1] *= 0.99999999
+    running_score[1] += 1.0
+    print( "%d: %f" % (i, running_score[0]/running_score[1]) )