X-Git-Url: http://git.treefish.org/~alex/shutbox.git/blobdiff_plain/9e3585ba3617d2af399e6a18bfd45b046f7c5e1e..HEAD:/src/qtable.py diff --git a/src/qtable.py b/src/qtable.py index 5bb1c51..29a3081 100755 --- a/src/qtable.py +++ b/src/qtable.py @@ -6,7 +6,7 @@ import sys from game import Game -learning_rate = 1.0 +learning_rate = 0.001 discount_factor = 1.0 states_dim = 36864 # 2^10 * 6^2 @@ -40,7 +40,6 @@ def select_option(opts, qs): decision_pt += qs[ opt_qid_pair[1] ] if ran_pt <= decision_pt: return (opt_qid_pair[0], opt_qid_pair[1]) - return (None, None) Q = np.ones([states_dim, actions_dim]) @@ -51,19 +50,20 @@ for i in range(num_episodes): g.dice() state_qid = find_state_qid(g.get_shutable(), g.get_diced()) while not g.is_over(): - opt, opt_qid = select_option( g.get_options(), Q[state_qid, :] ) - if opt: + options = g.get_options() + if len(options) > 0: + opt, opt_qid = select_option( options, Q[state_qid, :] ) old_score = g.get_score() g.shut(opt) g.dice() - reward = (g.get_score() - old_score) / 11.0 + reward = g.get_score() - old_score new_state_qid = find_state_qid(g.get_shutable(), g.get_diced()) Q[state_qid, opt_qid] += \ learning_rate * (reward + discount_factor * np.max(Q[new_state_qid, :]) - Q[state_qid, opt_qid]) state_qid = new_state_qid - Q[state_qid, :] = 0 + Q[state_qid, :] = 0.0 running_score[0] *= 0.99999999 running_score[0] += g.get_score() running_score[1] *= 0.99999999