From: Alexander Schmidt
Date: Mon, 2 Nov 2020 12:22:20 +0000 (+0100)
Subject: added qtable
X-Git-Url: http://git.treefish.org/~alex/shutbox.git/commitdiff_plain/86ceab763d09764aa693f896271eaefaf38dd9a9?ds=inline

added qtable
---

diff --git a/src/game.py b/src/game.py
index 5e2fe70..5fc133a 100644
--- a/src/game.py
+++ b/src/game.py
@@ -10,15 +10,19 @@ class Game:
     def dice(self):
         if not self._diced:
             self._diced = [random.randint(1, 6), random.randint(1, 6)]
+            self._diced.sort()
             for rods in [ self._diced,
                           [ abs(self._diced[0] - self._diced[1]) ],
                           [ self._diced[0] + self._diced[1] ] ]:
                 if self._can_be_shut(rods):
                     self._options.append(rods)
 
-    def get_dice(self):
+    def get_diced(self):
         return self._diced
 
+    def get_shutable(self):
+        return self._shutable
+
     def get_options(self):
         return self._options
 
diff --git a/src/qtable.py b/src/qtable.py
new file mode 100755
index 0000000..0ea19ec
--- /dev/null
+++ b/src/qtable.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+
+import numpy as np
+import random
+import sys
+
+from game import Game
+
+states_dim = 147456 # 2^12 * 6^2
+actions_dim = 637 # 12+1 * (6+1)^2
+num_episodes = 1000
+
+def find_state_qid(shutable, diced):
+    qid = 0
+    for rod in shutable:
+        qid += pow(2, rod-1)
+    for i in range(len(diced)):
+        qid += (diced[i]-1) * pow(6, i) * pow(2, 12)
+    return qid
+
+def find_option_qid(option):
+    qid = 0
+    for i in range(len(option)):
+        qid += option[i] * pow(7, i) * pow(13, len(option)-1)
+    return qid
+
+def select_option(opts, qs):
+    opt_qid_pairs = []
+    opt_qsum = 0.0
+    for opt in opts:
+        opt_qid = find_option_qid(opt)
+        opt_qid_pairs.append( [opt, opt_qid] )
+        opt_qsum += qs[opt_qid]
+    random.shuffle(opt_qid_pairs)
+    ran_pt = random.uniform(0.0, opt_qsum)
+    decision_pt = 0.0
+    for opt_qid_pair in opt_qid_pairs:
+        decision_pt += qs[ opt_qid_pair[1] ]
+        if ran_pt <= decision_pt:
+            return (opt_qid_pair[0], opt_qid_pair[1])
+    return (None, None)
+
+Q = np.zeros([states_dim, actions_dim])
+
+for i in range(num_episodes):
+    g = Game()
+    while not g.is_over():
+        g.dice()
+        state_qid = find_state_qid(g.get_shutable(), g.get_diced())
+        opt, opt_qid = select_option( g.get_options(), Q[state_qid, :] )
+        if opt:
+            g.shut(opt)
+    print( "%d: %d" % (i, g.get_score()) )